import glob
import json
import os
import random
from typing import Any, Dict, List

import torch
import torchvision
import wordsegment as ws
from huggingface_hub import hf_hub_url, cached_download
from PIL import Image

from virtex.config import Config
from virtex.factories import TokenizerFactory, PretrainingModelFactory, ImageTransformsFactory
from virtex.utils.checkpointing import CheckpointManager

CONFIG_PATH = "config.yaml"
MODEL_PATH = "checkpoint_last5.pth"
VALID_SUBREDDITS_PATH = "subreddit_list.json"
SAMPLES_PATH = "./samples/*.jpg"

class ImageLoader:
    """Loads images and applies the ImageNet-style preprocessing expected by VirTex."""

    def __init__(self):
        self.transformer = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def load(self, im_path):
        # Read an image from disk, preprocess it, and add a batch dimension.
        im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
        return {"image": im}

    def raw_load(self, im_path):
        # Read an image from disk without any preprocessing.
        im = torch.FloatTensor(Image.open(im_path))
        return {"image": im}

    def transform(self, image):
        # Preprocess an already-loaded PIL image and add a batch dimension.
        im = torch.FloatTensor(self.transformer(image)).unsqueeze(0)
        return {"image": im}

    def to_image(self, tensor):
        return torchvision.transforms.ToPILImage()(tensor)

class VirTexModel:
    """Wraps the pretrained VirTex captioning model, its tokenizer, and the subreddit list."""

    def __init__(self):
        self.config = Config(CONFIG_PATH)
        ws.load()  # load wordsegment dictionaries
        self.device = "cpu"
        self.tokenizer = TokenizerFactory.from_config(self.config)
        self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
        CheckpointManager(model=self.model).load(MODEL_PATH)
        self.model.eval()
        self.valid_subs = json.load(open(VALID_SUBREDDITS_PATH))
    def predict(self, image_dict, sub_prompt=None, prompt=""):
        # Seed decoding with either the start-of-sequence token or the token
        # for a user-supplied subreddit prompt.
        if sub_prompt is None:
            subreddit_tokens = torch.tensor([self.model.sos_index], device=self.device).long()
        else:
            subreddit_tokens = torch.tensor(
                [self.tokenizer.token_to_id(sub_prompt)], device=self.device
            ).long()
        predictions: List[Dict[str, Any]] = []
        is_valid_subreddit = False
        subreddit, rest_of_caption = "", ""
        image_dict["decode_prompt"] = subreddit_tokens

        # Resample until the generated caption starts with a known subreddit
        # (skipped when the subreddit was supplied explicitly).
        while not is_valid_subreddit:
            with torch.no_grad():
                caption = self.model(image_dict)["predictions"][0].tolist()

            # Replace the [SEP] token with "://" so the decoded string splits
            # cleanly into subreddit and caption parts.
            if self.tokenizer.token_to_id("[SEP]") in caption:
                sep_index = caption.index(self.tokenizer.token_to_id("[SEP]"))
                caption[sep_index] = self.tokenizer.token_to_id("://")

            caption = self.tokenizer.decode(caption)

            if "://" in caption:
                subreddit, rest_of_caption = caption.split("://", 1)
                subreddit = "".join(subreddit.split())
                rest_of_caption = rest_of_caption.strip()
            else:
                subreddit, rest_of_caption = "", caption

            is_valid_subreddit = sub_prompt is not None or subreddit in self.valid_subs

        return subreddit, rest_of_caption

def download_files():
    # Download the model config, checkpoint, and subreddit list from the Hub,
    # then copy them into the working directory.
    files = [CONFIG_PATH, MODEL_PATH, VALID_SUBREDDITS_PATH]
    for f in files:
        fp = cached_download(hf_hub_url("zamborg/redcaps", filename=f))
        os.system(f"cp {fp} ./{f}")

def get_samples():
    return glob.glob(SAMPLES_PATH)


def get_rand_img(samples):
    return random.choice(samples)
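

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it assumes the
    # files named above exist in the "zamborg/redcaps" Hub repo and that
    # ./samples/ contains at least one *.jpg image.
    download_files()
    loader = ImageLoader()
    virtex = VirTexModel()

    sample_path = get_rand_img(get_samples())
    image_dict = loader.load(sample_path)
    subreddit, caption = virtex.predict(image_dict)
    print(subreddit, caption)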