Spaces:
Runtime error
Runtime error
File size: 5,708 Bytes
5d3b8a6 7d1df38 f307fe5 7d1df38 65193db a8416ee 3150e77 214bd84 09da12b 214bd84 4dab50d 65193db 4dab50d 79c7b01 f307fe5 a8416ee f307fe5 79c7b01 4dab50d 65193db 4dab50d 79c7b01 ab30850 79c7b01 09da12b 1df8b5f 79c7b01 7d1df38 7cc986f 7d1df38 b80df5c 16484d3 b80df5c 250dc27 b80df5c c838395 ed768de b58ad35 11ab28e b58ad35 d92334f b80df5c 7d1df38 4dab50d 7d1df38 11ab28e 7d1df38 defbed4 7d1df38 2fd38cf 5d3b8a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import streamlit as st
from huggingface_hub import hf_hub_url, cached_download
from PIL import Image
import os
import json
import glob
import random
from typing import Any, Dict, List
import torch
import torchvision
import wordsegment as ws
from virtex.config import Config
from virtex.factories import TokenizerFactory, PretrainingModelFactory, ImageTransformsFactory
from virtex.utils.checkpointing import CheckpointManager
# Local file names/patterns used by this module.
# The first three are fetched by download_files() from the "zamborg/redcaps"
# Hub repository and copied into the working directory under these names.
# Model configuration file, passed to virtex Config() in VirTexModel.__init__.
CONFIG_PATH = "config.yaml"
# Checkpoint file loaded through CheckpointManager in VirTexModel.__init__.
MODEL_PATH = "checkpoint_last5.pth"
# JSON file read with json.load; predict() tests membership of the generated
# subreddit in its contents.
VALID_SUBREDDITS_PATH = "subreddit_list.json"
# Glob pattern expanded by get_samples() to find the sample images.
SAMPLES_PATH = "./samples/*.jpg"
class ImageLoader():
    """Prepare images for the captioning model and for on-screen display."""

    def __init__(self):
        """Build the model preprocessing pipeline and set the display size."""
        # Model-input pipeline: tensor conversion, Resize(256), CenterCrop(224),
        # then normalisation with three-channel mean/std values. Because the
        # statistics have exactly three entries, inputs must be 3-channel.
        self.image_transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
        ])
        # Target length (pixels) of the longer side produced by show_resize().
        self.show_size = 500

    def load(self, im_path):
        """Open the image at ``im_path`` and return it as a model input.

        Args:
            im_path: Path (or file object) accepted by ``PIL.Image.open``.

        Returns:
            ``{"image": tensor}`` where the tensor is the transformed image
            with one extra leading dimension (added via ``unsqueeze(0)``).
        """
        # The context manager releases the file handle; convert() returns an
        # independent, fully loaded copy, so it stays usable after the close.
        # Converting to RGB also keeps RGBA / greyscale / palette files from
        # breaking the three-channel Normalize step.
        with Image.open(im_path) as opened:
            rgb_image = opened.convert("RGB")
        im = torch.FloatTensor(self.image_transform(rgb_image)).unsqueeze(0)
        return {"image": im}

    def raw_load(self, im_path):
        """Open the image at ``im_path`` and wrap it without any transform.

        Returns:
            ``{"image": tensor}`` built directly from the opened image.
        """
        # NOTE(review): the PIL image object is handed straight to
        # torch.FloatTensor with no conversion step - confirm this constructor
        # accepts a PIL image with the pinned torch version before relying on
        # this method. Left unchanged because the intended layout/scaling of
        # the "raw" tensor cannot be determined from this file.
        im = torch.FloatTensor(Image.open(im_path))
        return {"image": im}

    def transform(self, image):
        """Apply the model preprocessing pipeline to an in-memory image.

        Args:
            image: Image accepted by ``torchvision.transforms.ToTensor``.
                PIL images are converted to RGB first so that inputs with an
                alpha channel, greyscale or palette mode do not fail in the
                three-channel normalisation. Other inputs are passed through
                untouched.

        Returns:
            ``{"image": tensor}`` with one extra leading dimension.
        """
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        im = torch.FloatTensor(self.image_transform(image)).unsqueeze(0)
        return {"image": im}

    def text_transform(self, text):
        """Normalise prompt text; at present this only lowercases it."""
        return text.lower()

    def show_resize(self, image):
        """Return a PIL image rescaled so its longer side is about ``show_size``.

        The aspect ratio is kept (up to integer truncation of each side).
        """
        # The original author notes the resize is done by hand on a tensor
        # because of the old library version pinned for this app.
        tensor = torchvision.transforms.functional.to_tensor(image)
        # to_tensor yields (..., H, W), so the last two dims are height, width.
        height, width = tensor.shape[-2:]
        ratio = float(self.show_size / max(height, width))
        # Clamp to at least one pixel so very elongated images do not end up
        # with a zero-sized dimension after truncation.
        new_size = [max(1, int(height * ratio)), max(1, int(width * ratio))]
        tensor = torchvision.transforms.functional.resize(tensor, new_size)
        return torchvision.transforms.functional.to_pil_image(tensor)
class VirTexModel():
    """Wrap a VirTex model to generate a (subreddit, caption) pair for an image."""

    def __init__(self):
        """Load the config, tokenizer, model checkpoint and valid-subreddit list.

        Reads ``CONFIG_PATH``, ``MODEL_PATH`` and ``VALID_SUBREDDITS_PATH``
        from the working directory; the model is placed on the CPU and put in
        eval mode.
        """
        self.config = Config(CONFIG_PATH)
        # Load the wordsegment data used in predict() to split subreddit prompts.
        ws.load()
        self.device = 'cpu'
        self.tokenizer = TokenizerFactory.from_config(self.config)
        self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
        CheckpointManager(model=self.model).load(MODEL_PATH)
        self.model.eval()
        # Context manager so the file handle is closed deterministically.
        with open(VALID_SUBREDDITS_PATH) as subs_file:
            self.valid_subs = json.load(subs_file)

    def predict(self, image_dict, sub_prompt = None, prompt = ""):
        """Generate a subreddit and caption for the image in ``image_dict``.

        Args:
            image_dict: Model input dict (e.g. as built by ``ImageLoader``).
                It is mutated: the key ``"decode_prompt"`` is set to the
                prompt token tensor.
            sub_prompt: Optional subreddit name forced as the start of the
                decoded sequence. ``None`` lets the model choose.
            prompt: Optional caption prefix. Empty (or ``None``) means no
                caption prompt. When given without ``sub_prompt`` the
                subreddit ``"pics"`` is used, because caption prompts without
                a subreddit are not supported yet (TODO from the original
                author).

        Returns:
            Tuple ``(subreddit, rest_of_caption)`` of strings, where
            ``subreddit`` is a member of ``self.valid_subs``.
        """
        sep_id = self.tokenizer.token_to_id("[SEP]")

        if sub_prompt is None:
            prompt_ids = [self.model.sos_index]
        else:
            # Split a run-together subreddit name into words before encoding.
            segmented = " ".join(ws.segment(ws.clean(sub_prompt)))
            prompt_ids = [self.model.sos_index] + self.tokenizer.encode(segmented) + [sep_id]

        # Truthiness test instead of the original identity comparison with a
        # string literal (`is not ""`), which depends on interning and treats
        # None as a real prompt.
        if prompt:
            if sub_prompt is None:
                # Caption prompts need a subreddit in front; fall back to "pics".
                prompt_ids = [self.model.sos_index] + self.tokenizer.encode("pics") + [sep_id]
            # NOTE(review): a second [SEP] is appended after the one that
            # already ends the subreddit part. This reproduces the original
            # token sequence exactly - confirm it is what the model expects.
            prompt_ids = prompt_ids + [sep_id] + self.tokenizer.encode(prompt)

        image_dict["decode_prompt"] = torch.tensor(prompt_ids, device=self.device).long()

        subreddit, rest_of_caption = "", ""
        is_valid_subreddit = False
        # Re-run the model until the decoded subreddit is in the valid list.
        # NOTE(review): there is no retry limit; this presumably relies on the
        # model's decoding being stochastic - confirm, otherwise an invalid
        # deterministic output would loop forever.
        while not is_valid_subreddit:
            with torch.no_grad():
                caption = self.model(image_dict)["predictions"][0].tolist()

            # Swap the first [SEP] for the "://" token so that the decoded text
            # carries a visible subreddit/caption delimiter.
            if sep_id in caption:
                caption[caption.index(sep_id)] = self.tokenizer.token_to_id("://")

            caption = self.tokenizer.decode(caption)

            if "://" in caption:
                # Split only on the first delimiter: the caption text itself
                # may contain "://" again, which previously made the two-name
                # unpacking raise ValueError.
                subreddit, rest_of_caption = caption.split("://", 1)
                subreddit = "".join(subreddit.split())
                rest_of_caption = rest_of_caption.strip()
            else:
                subreddit, rest_of_caption = "", caption

            is_valid_subreddit = subreddit in self.valid_subs

        return subreddit, rest_of_caption
def download_files():
    """Fetch the model files from the Hub and copy them into the working directory.

    Downloads ``CONFIG_PATH``, ``MODEL_PATH`` and ``VALID_SUBREDDITS_PATH``
    from the ``zamborg/redcaps`` repository and places each one in the current
    directory under the same name.

    Raises:
        OSError: If a downloaded file cannot be copied (the previous shell
            ``cp`` call ignored such failures silently).
    """
    # Local import: shutil is not part of the module's top-level imports.
    import shutil

    for filename in (CONFIG_PATH, MODEL_PATH, VALID_SUBREDDITS_PATH):
        cached_path = cached_download(hf_hub_url("zamborg/redcaps", filename=filename))
        # Copy through the standard library instead of building a shell
        # command: portable, safe for paths with spaces, and failures raise.
        shutil.copy(cached_path, os.path.join(".", filename))
def get_samples():
    """Return the list of file paths matching the ``SAMPLES_PATH`` glob pattern."""
    sample_paths = glob.glob(SAMPLES_PATH)
    return sample_paths
def get_rand_idx(samples):
    """Return a uniformly random valid index into ``samples``.

    Args:
        samples: Any sized collection (only ``len(samples)`` is used).

    Returns:
        An int in ``[0, len(samples) - 1]``.

    Raises:
        ValueError: If ``samples`` is empty. The exception type is the same
            as before; the message now names the actual problem.
    """
    count = len(samples)
    if count == 0:
        raise ValueError("cannot pick a random index from an empty sample list")
    # randrange(count) draws from the same range as randint(0, count - 1).
    return random.randrange(count)
@st.cache(allow_output_mutation=True)  # allow mutation to update nucleus size
def create_objects():
    """Build the objects the app needs; cached by Streamlit via the decorator.

    Returns:
        Tuple ``(virtexModel, imageLoader, sample_images, valid_subs)``:
        the loaded ``VirTexModel``, an ``ImageLoader``, the sample image
        paths from ``get_samples()``, and the contents of
        ``VALID_SUBREDDITS_PATH`` with ``None`` inserted as the first entry.
    """
    sample_images = get_samples()
    virtexModel = VirTexModel()
    imageLoader = ImageLoader()
    # Context manager so the JSON file handle is closed deterministically.
    with open(VALID_SUBREDDITS_PATH) as subs_file:
        valid_subs = json.load(subs_file)
    # Presumably None is the "no subreddit prompt" choice (it matches the
    # default of VirTexModel.predict's sub_prompt) - confirm against the UI code.
    valid_subs.insert(0, None)
    return virtexModel, imageLoader, sample_images, valid_subs
|