import torch
import nltk
import io
import base64
from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample
from torchvision import transforms


class PreTrainedPipeline():
    def __init__(self, path=""):
        """
        Initialize the model: download the WordNet data needed by
        one_hot_from_names and load the pretrained BigGAN weights from `path`.
        """
        nltk.download('wordnet')
        self.model = BigGAN.from_pretrained(path)
        self.truncation = 0.1

    def __call__(self, inputs: str) -> str:
        """
        Args:
            inputs (:obj:`str`):
                the name of an ImageNet class to generate an image for
        Return:
            A :obj:`str`: a base64-encoded JPEG image generated for the input class.
        """
        # Map the class name to a one-hot ImageNet class vector.
        class_vector = one_hot_from_names([inputs], batch_size=1)
        if class_vector is None:
            raise ValueError("Input is not in ImageNet")
        noise_vector = truncated_noise_sample(truncation=self.truncation, batch_size=1)
        noise_vector = torch.from_numpy(noise_vector)
        class_vector = torch.from_numpy(class_vector)
        with torch.no_grad():
            output = self.model(noise_vector, class_vector, self.truncation)
        # BigGAN outputs lie in [-1, 1]; rescale to [0, 1] before converting to a PIL image.
        img = transforms.ToPILImage()(((output[0] + 1) / 2).clamp(0, 1))
        buf = io.BytesIO()
        img.save(buf, format="JPEG")
        return base64.encodebytes(buf.getvalue()).decode('utf-8')
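
# Example usage (a minimal sketch, not part of the original file): "biggan-deep-256"
# is assumed to be a checkpoint name accepted by BigGAN.from_pretrained, and
# "golden retriever" is assumed to be a recognized ImageNet class name.
#
#     pipeline = PreTrainedPipeline(path="biggan-deep-256")
#     image_b64 = pipeline("golden retriever")
#     with open("sample.jpg", "wb") as f:
#         f.write(base64.b64decode(image_b64))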