import base64
import io

import nltk
import torch
from pytorch_pretrained_biggan import (
    BigGAN,
    convert_to_images,
    one_hot_from_names,
    truncated_noise_sample,
)


class PreTrainedPipeline():
    """Text-to-image pipeline that renders an ImageNet class name with BigGAN."""

    def __init__(self, path=""):
        """Download the WordNet corpus and load the BigGAN weights.

        Args:
            path: Model identifier or directory forwarded unchanged to
                ``BigGAN.from_pretrained``.
        """
        # one_hot_from_names resolves class names through WordNet, so the
        # corpus has to be available before the first call.
        nltk.download('wordnet')
        self.model = BigGAN.from_pretrained(path)
        # Shared by the noise sampler and the generator forward pass.
        self.truncation = 0.1

    def __call__(self, inputs: str) -> str:
        """Generate one image for the class named by ``inputs``.

        Args:
            inputs (:obj:`str`): a class name to look up among the
                ImageNet classes.

        Return:
            A :obj:`str`: the generated image encoded as JPEG and then as
            base64 text (``base64.encodebytes`` format, i.e. with line
            breaks).

        Raises:
            ValueError: if ``one_hot_from_names`` cannot map ``inputs`` to
                a class and returns ``None``.
        """
        class_vector = one_hot_from_names([inputs], batch_size=1)
        if class_vector is None:
            raise ValueError("Input is not in ImageNet")

        noise_vector = truncated_noise_sample(
            truncation=self.truncation, batch_size=1
        )
        noise_vector = torch.from_numpy(noise_vector)
        class_vector = torch.from_numpy(class_vector)

        with torch.no_grad():
            output = self.model(noise_vector, class_vector, self.truncation)

        # The original code called an undefined `transforms` name here.
        # convert_to_images is the library's own tensor -> PIL conversion
        # (NOTE(review): it rescales the generator output, presumably from
        # [-1, 1], to 8-bit pixels - confirm against the package docs).
        img = convert_to_images(output)[0]

        buf = io.BytesIO()
        img.save(buf, format="JPEG")
        return base64.encodebytes(buf.getvalue()).decode('utf-8')