"""Helpers for running a VirTex captioning checkpoint on single images.

Contains an image loader, a thin model wrapper that generates a
``(subreddit, caption)`` pair for an image, and utilities for fetching the
model files and picking sample images.
"""
import glob
import json
import os
import random
import shutil
from typing import Any, Dict, List, Optional

import torch
import torchvision
import wordsegment as ws
from huggingface_hub import hf_hub_url, cached_download
from PIL import Image

from virtex.config import Config
from virtex.factories import TokenizerFactory, PretrainingModelFactory, ImageTransformsFactory
from virtex.utils.checkpointing import CheckpointManager

# File names are used both as remote file names (see ``download_files``) and
# as paths relative to the current working directory.
CONFIG_PATH = "config.yaml"
MODEL_PATH = "checkpoint_last5.pth"
VALID_SUBREDDITS_PATH = "subreddit_list.json"
# Glob pattern matched by ``get_samples``.
SAMPLES_PATH = "./samples/*.jpg"


class ImageLoader():
    """Turn image files / image objects into the input dict used by the model."""

    def __init__(self):
        # NOTE(review): "normalize" runs before ``ToTensor`` here; confirm that
        # the factory transforms accept the objects passed through this
        # pipeline in this order.
        self.transformer = torchvision.transforms.Compose([
            ImageTransformsFactory.create("smallest_resize"),
            ImageTransformsFactory.create("center_crop"),
            ImageTransformsFactory.create("normalize"),
            torchvision.transforms.ToTensor(),
        ])

    def load(self, im_path) -> Dict[str, Any]:
        """Open ``im_path`` with PIL and return it transformed (see ``transform``)."""
        return self.transform(Image.open(im_path))

    def raw_load(self, im_path) -> Dict[str, Any]:
        """Open ``im_path`` and wrap it in a tensor without the transform pipeline.

        Returns:
            ``{"image": tensor}`` where the tensor has a leading dimension of
            size 1 added.
        """
        # NOTE(review): the PIL image is handed to ``torch.FloatTensor``
        # directly, with no explicit array conversion -- confirm this works
        # with the torch version in use.
        im = torch.FloatTensor(Image.open(im_path)).unsqueeze(0)
        return {"image": im}

    def transform(self, image) -> Dict[str, Any]:
        """Run ``image`` through the transform pipeline.

        Returns:
            ``{"image": tensor}`` where the tensor has a leading dimension of
            size 1 added.
        """
        im = torch.FloatTensor(self.transformer(image)).unsqueeze(0)
        return {"image": im}

    def to_image(self, tensor):
        """Convert ``tensor`` to a PIL image via ``torchvision``'s ``ToPILImage``."""
        return torchvision.transforms.ToPILImage()(tensor)


class VirTexModel():
    """Wrapper that builds the tokenizer and model from config and loads weights.

    Construction reads ``CONFIG_PATH``, ``MODEL_PATH`` and
    ``VALID_SUBREDDITS_PATH`` from the working directory, so those files must
    exist before instantiating (see ``download_files``).
    """

    def __init__(self):
        self.config = Config(CONFIG_PATH)
        ws.load()
        self.device = 'cpu'
        self.tokenizer = TokenizerFactory.from_config(self.config)
        self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
        # Same path string as before ("./checkpoint_last5.pth"), but derived
        # from the constant so the two cannot drift apart.
        CheckpointManager(model=self.model).load(f"./{MODEL_PATH}")
        self.model.eval()
        # Use a context manager so the file handle is closed deterministically.
        with open(VALID_SUBREDDITS_PATH) as subs_file:
            self.valid_subs = json.load(subs_file)

    def predict(self, image_dict, sub_prompt: Optional[str] = None, prompt: str = ""):
        """Generate a ``(subreddit, caption)`` pair for an image.

        Args:
            image_dict: Model input dict (e.g. from ``ImageLoader.transform``).
                It is modified in place: a ``"decode_prompt"`` entry is set.
            sub_prompt: Optional token used as the decoding prompt instead of
                the model's start-of-sequence index. When given, the first
                generated caption is returned without validating the
                subreddit.
            prompt: Currently unused; kept for interface compatibility.

        Returns:
            Tuple ``(subreddit, rest_of_caption)``. ``subreddit`` is the text
            before the first ``"://"`` with all whitespace removed, or ``""``
            if the decoded caption contains no ``"://"``.

        Note:
            Without ``sub_prompt`` the model is called repeatedly until the
            generated subreddit is found in ``self.valid_subs``; there is no
            retry limit.
        """
        if sub_prompt is None:
            start_token = self.model.sos_index
        else:
            start_token = self.tokenizer.token_to_id(sub_prompt)
        image_dict["decode_prompt"] = torch.tensor(
            [start_token], device=self.device
        ).long()

        # Loop-invariant lookup, hoisted out of the retry loop.
        sep_id = self.tokenizer.token_to_id("[SEP]")

        while True:
            with torch.no_grad():
                caption = self.model(image_dict)["predictions"][0].tolist()

            # Replace the first [SEP] token with "://" so that the subreddit
            # part can be separated from the caption after decoding.
            if sep_id in caption:
                caption[caption.index(sep_id)] = self.tokenizer.token_to_id("://")
            caption = self.tokenizer.decode(caption)

            if "://" in caption:
                # Split on the FIRST "://" only: the caption text itself may
                # contain further "://" occurrences, which previously made the
                # two-name unpacking of ``split`` raise ValueError.
                subreddit, _, rest_of_caption = caption.partition("://")
                subreddit = "".join(subreddit.split())
                rest_of_caption = rest_of_caption.strip()
            else:
                subreddit, rest_of_caption = "", caption

            if sub_prompt is not None or subreddit in self.valid_subs:
                return subreddit, rest_of_caption


def download_files():
    """Fetch config, checkpoint and subreddit list into the working directory.

    Each file is resolved through the hub cache and then copied to
    ``./<name>``.

    Raises:
        OSError: If copying a cached file into the working directory fails.
    """
    # NOTE(review): ``cached_download`` is reportedly deprecated in newer
    # ``huggingface_hub`` releases -- confirm against the pinned version.
    filenames = [CONFIG_PATH, MODEL_PATH, VALID_SUBREDDITS_PATH]
    for name in filenames:
        cached_path = cached_download(hf_hub_url("zamborg/redcaps", filename=name))
        # ``shutil.copy`` instead of shelling out to ``cp``: portable, immune
        # to shell-special characters in paths, and failures raise instead of
        # being silently ignored.
        shutil.copy(cached_path, os.path.join(".", name))


def get_samples() -> List[str]:
    """Return the paths matching ``SAMPLES_PATH``."""
    return glob.glob(SAMPLES_PATH)


def get_rand_img(samples):
    """Return one uniformly chosen element of ``samples``.

    Raises:
        ValueError: If ``samples`` is empty.
    """
    return samples[random.randrange(len(samples))]