virtex-redcaps / model.py
zamborg's picture
syntax fix
1df8b5f
raw
history blame
4.98 kB
from huggingface_hub import hf_hub_url, cached_download
from PIL import Image
import os
import json
import glob
import random
from typing import Any, Dict, List
import torch
import torchvision
import wordsegment as ws
from virtex.config import Config
from virtex.factories import TokenizerFactory, PretrainingModelFactory, ImageTransformsFactory
from virtex.utils.checkpointing import CheckpointManager
# File names used by this module, all relative to the working directory.
# CONFIG_PATH, MODEL_PATH and VALID_SUBREDDITS_PATH are the files that
# download_files() fetches and copies next to the script.
CONFIG_PATH = "config.yaml"  # passed to virtex's Config in VirTexModel.__init__
MODEL_PATH = "checkpoint_last5.pth"  # model checkpoint file name
VALID_SUBREDDITS_PATH = "subreddit_list.json"  # JSON loaded into VirTexModel.valid_subs
# Glob pattern expanded by get_samples() to list the bundled sample images.
SAMPLES_PATH = "./samples/*.jpg"
class ImageLoader():
    """Convert image files / images into the ``{"image": tensor}`` dicts used here.

    The preprocessing pipeline (``self.transformer``) is: ToTensor, Resize(256),
    CenterCrop(224), then Normalize with fixed per-channel mean and std.
    """

    def __init__(self):
        tfm = torchvision.transforms
        pipeline = [
            tfm.ToTensor(),
            tfm.Resize(256),
            tfm.CenterCrop(224),
            # Per-channel mean and std handed to Normalize (three channels).
            tfm.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
        self.transformer = tfm.Compose(pipeline)
        # Length, in pixels, that show_resize() gives to the longer image side.
        self.show_size = 500

    def load(self, im_path):
        """Open the image at ``im_path`` and preprocess it like ``transform``.

        Returns a dict ``{"image": tensor}`` where the tensor has a leading
        dimension of size 1 added in front of the transformed image.
        """
        # NOTE(review): the image is not converted to RGB here; the 3-value
        # Normalize presumably requires 3-channel input -- TODO confirm.
        opened = Image.open(im_path)
        return self.transform(opened)

    def raw_load(self, im_path):
        """Open the image at ``im_path`` without running the preprocessing pipeline.

        Returns a dict ``{"image": tensor}``.
        """
        # NOTE(review): the PIL image is handed straight to torch.FloatTensor;
        # verify that this constructor accepts a PIL image.
        opened = Image.open(im_path)
        return {"image": torch.FloatTensor(opened)}

    def transform(self, image):
        """Run ``self.transformer`` on ``image`` and add a leading batch dimension.

        Returns a dict ``{"image": tensor}``.
        """
        processed = self.transformer(image)
        batched = torch.FloatTensor(processed).unsqueeze(0)
        return {"image": batched}

    def text_transform(self, text):
        """Normalise prompt text; currently this only lowercases it."""
        lowered = text.lower()
        return lowered

    def show_resize(self, image):
        """Return a PIL copy of ``image`` whose longer side is ``self.show_size``.

        The aspect ratio is kept; both target side lengths are truncated to int.
        """
        # Done by hand through the functional API; the original author's note
        # attributes this to the installed library being version 0.8, not 1.9.
        fn = torchvision.transforms.functional
        as_tensor = fn.to_tensor(image)
        # Last two dims of the tensor (height, width per torchvision's
        # to_tensor docs -- TODO confirm).
        height, width = as_tensor.shape[-2:]
        scale = float(self.show_size / max(height, width))
        target = [int(height * scale), int(width * scale)]
        shrunk = fn.resize(as_tensor, target)
        return fn.to_pil_image(shrunk)
class VirTexModel():
    """CPU wrapper around a VirTex pretraining model that captions images.

    ``predict`` returns a ``(subreddit, caption)`` pair and can optionally be
    steered with a subreddit prompt and/or a caption prefix.
    """

    def __init__(self):
        """Build tokenizer and model from ``CONFIG_PATH`` and load the checkpoint.

        Reads ``CONFIG_PATH``, ``MODEL_PATH`` and ``VALID_SUBREDDITS_PATH`` from
        the working directory.
        """
        self.config = Config(CONFIG_PATH)
        # wordsegment has to be loaded before ws.segment() is used in predict().
        ws.load()
        self.device = 'cpu'
        self.tokenizer = TokenizerFactory.from_config(self.config)
        self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
        CheckpointManager(model=self.model).load(MODEL_PATH)
        self.model.eval()
        # Context manager so the file handle is closed deterministically.
        with open(VALID_SUBREDDITS_PATH) as subs_file:
            self.valid_subs = json.load(subs_file)

    def predict(self, image_dict, sub_prompt = None, prompt = ""):
        """Generate a ``(subreddit, caption)`` pair for an image.

        Args:
            image_dict: Dict passed to the model. It is mutated: the key
                ``"decode_prompt"`` is set to the token prefix built here.
            sub_prompt: Optional subreddit name. It is cleaned and segmented
                with wordsegment, encoded, and followed by a ``[SEP]`` token.
            prompt: Optional caption prefix appended after the subreddit part.
                An empty string or ``None`` means "no caption prompt".

        Returns:
            Tuple ``(subreddit, rest_of_caption)`` of strings. ``subreddit`` is
            ``""`` when the decoded caption contains no ``"://"`` separator.

        When neither ``sub_prompt`` nor ``prompt`` is supplied, the model is
        called again until the generated subreddit is in ``self.valid_subs``.
        """
        sep_id = self.tokenizer.token_to_id("[SEP]")
        has_prompt = bool(prompt)

        # Token prefix: [SOS] (+ subreddit tokens + [SEP]) (+ caption tokens).
        prefix = [self.model.sos_index]
        if sub_prompt is not None:
            segmented = " ".join(ws.segment(ws.clean(sub_prompt)))
            prefix = prefix + self.tokenizer.encode(segmented) + [sep_id]
        if has_prompt:
            # The original author notes that caption prompts without a
            # subreddit prompt are appended unconditionally (open TODO there).
            prefix = prefix + self.tokenizer.encode(prompt)
        image_dict["decode_prompt"] = torch.tensor(prefix, device=self.device).long()

        # With any user-supplied prompt the output is accepted as-is;
        # otherwise the generated subreddit must be a known one.
        accept_any = sub_prompt is not None or has_prompt

        # NOTE(review): this loop has no iteration bound; it presumably relies
        # on stochastic decoding to eventually yield a valid subreddit -- confirm.
        while True:
            with torch.no_grad():
                caption = self.model(image_dict)["predictions"][0].tolist()

            # Swap the first [SEP] for the "://" token so that the decoded
            # text carries a visible subreddit/caption separator.
            if sep_id in caption:
                caption[caption.index(sep_id)] = self.tokenizer.token_to_id("://")
            caption = self.tokenizer.decode(caption)

            if "://" in caption:
                # partition() splits on the first separator only, so a second
                # "://" inside the caption text cannot break the unpacking.
                subreddit, _, rest_of_caption = caption.partition("://")
                subreddit = "".join(subreddit.split())
                rest_of_caption = rest_of_caption.strip()
            else:
                subreddit, rest_of_caption = "", caption

            if accept_any or subreddit in self.valid_subs:
                return subreddit, rest_of_caption
def download_files():
    """Fetch the model files from the Hub and copy them into the working directory.

    Downloads ``CONFIG_PATH``, ``MODEL_PATH`` and ``VALID_SUBREDDITS_PATH`` from
    the ``zamborg/redcaps`` repository via ``cached_download`` and copies each
    cached file to ``./<name>``.

    Raises:
        OSError: if a cached file cannot be copied (the previous shell ``cp``
            call ignored such failures).
    """
    # Function-scope import keeps this block self-contained.
    import shutil

    for filename in (CONFIG_PATH, MODEL_PATH, VALID_SUBREDDITS_PATH):
        cached_path = cached_download(hf_hub_url("zamborg/redcaps", filename=filename))
        # shutil.copy replaces the shell `cp`: no shell string is built, it
        # works without a `cp` binary, and a failed copy raises.
        shutil.copy(cached_path, os.path.join(".", filename))
def get_samples():
    """Return the list of paths matching the ``SAMPLES_PATH`` glob pattern."""
    sample_paths = glob.glob(SAMPLES_PATH)
    return sample_paths
def get_rand_img(samples):
    """Pick one entry of ``samples`` uniformly at random.

    Returns:
        Tuple ``(index, samples[index])``.

    Raises:
        ValueError: if ``samples`` is empty.
    """
    index = random.randrange(len(samples))
    chosen = samples[index]
    return index, chosen