osanseviero
commited on
Commit
•
8444104
1
Parent(s):
58caaa4
Update pipeline.py
Browse files- pipeline.py +8 -2
pipeline.py
CHANGED
@@ -2,6 +2,7 @@ import torch
|
|
2 |
import nltk
|
3 |
import io
|
4 |
import base64
|
|
|
5 |
from torchvision import transforms
|
6 |
|
7 |
from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample
|
@@ -11,8 +12,13 @@ class PreTrainedPipeline():
|
|
11 |
"""
|
12 |
Initialize model
|
13 |
"""
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
16 |
self.truncation = 0.1
|
17 |
|
18 |
def __call__(self, inputs: str):
|
|
|
2 |
import nltk
|
3 |
import io
|
4 |
import base64
|
5 |
+
import shutil
|
6 |
from torchvision import transforms
|
7 |
|
8 |
from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample
|
|
|
12 |
"""
|
13 |
Initialize model
|
14 |
"""
|
15 |
+
try:
|
16 |
+
self.model = BigGAN.from_pretrained(path)
|
17 |
+
except (IOError, OSError):
|
18 |
+
directory = "/data/corpora"
|
19 |
+
shutil.rmtree(directory)
|
20 |
+
nltk.download('wordnet')
|
21 |
+
self.model = BigGAN.from_pretrained(path)
|
22 |
self.truncation = 0.1
|
23 |
|
24 |
def __call__(self, inputs: str):
|