osanseviero
commited on
Commit
·
740f729
1
Parent(s):
af0640a
Update pipeline.py
Browse files- pipeline.py +7 -10
pipeline.py
CHANGED
@@ -16,29 +16,26 @@ class PreTrainedPipeline():
|
|
16 |
self.model = BigGAN.from_pretrained(path)
|
17 |
self.truncation = 0.1
|
18 |
|
19 |
-
|
20 |
-
def __call__(self, inputs: str) -> str:
|
21 |
"""
|
22 |
Args:
|
23 |
inputs (:obj:`str`):
|
24 |
a string containing some text
|
25 |
Return:
|
26 |
-
A :obj:`
|
27 |
"""
|
28 |
class_vector = one_hot_from_names([inputs], batch_size=1)
|
29 |
if type(class_vector) == type(None):
|
30 |
raise ValueError("Input is not in ImageNet")
|
31 |
-
|
32 |
noise_vector = truncated_noise_sample(truncation=self.truncation, batch_size=1)
|
33 |
-
|
34 |
noise_vector = torch.from_numpy(noise_vector)
|
35 |
class_vector = torch.from_numpy(class_vector)
|
36 |
-
|
37 |
with torch.no_grad():
|
38 |
-
output = self.model(noise_vector, class_vector,
|
39 |
|
40 |
img = transforms.ToPILImage()(output[0])
|
41 |
-
|
42 |
-
img.save(
|
|
|
43 |
|
44 |
-
return
|
|
|
16 |
self.model = BigGAN.from_pretrained(path)
|
17 |
self.truncation = 0.1
|
18 |
|
19 |
+
def __call__(self, inputs: str):
|
|
|
20 |
"""
|
21 |
Args:
|
22 |
inputs (:obj:`str`):
|
23 |
a string containing some text
|
24 |
Return:
|
25 |
+
A :obj:`PIL.Image`. The raw image representation as PIL.
|
26 |
"""
|
27 |
class_vector = one_hot_from_names([inputs], batch_size=1)
|
28 |
if type(class_vector) == type(None):
|
29 |
raise ValueError("Input is not in ImageNet")
|
|
|
30 |
noise_vector = truncated_noise_sample(truncation=self.truncation, batch_size=1)
|
|
|
31 |
noise_vector = torch.from_numpy(noise_vector)
|
32 |
class_vector = torch.from_numpy(class_vector)
|
|
|
33 |
with torch.no_grad():
|
34 |
+
output = self.model(noise_vector, class_vector, truncation)
|
35 |
|
36 |
img = transforms.ToPILImage()(output[0])
|
37 |
+
buffer = BytesIO()
|
38 |
+
img.save(buffer, format="JPEG")
|
39 |
+
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
40 |
|
41 |
+
return img_str
|