Spaces:
Runtime error
Runtime error
fixed normalization
Browse files
app.py
CHANGED
@@ -54,7 +54,7 @@ else:
|
|
54 |
|
55 |
image_dict = imageLoader.transform(image)
|
56 |
|
57 |
-
|
58 |
|
59 |
show = st.image(image)
|
60 |
show.image(image, "Your Image")
|
|
|
54 |
|
55 |
image_dict = imageLoader.transform(image)
|
56 |
|
57 |
+
# image = imageLoader.to_image(image_dict["image"].squeeze(0))
|
58 |
|
59 |
show = st.image(image)
|
60 |
show.image(image, "Your Image")
|
model.py
CHANGED
@@ -11,7 +11,7 @@ import torchvision
|
|
11 |
import wordsegment as ws
|
12 |
|
13 |
from virtex.config import Config
|
14 |
-
from virtex.factories import TokenizerFactory, PretrainingModelFactory
|
15 |
from virtex.utils.checkpointing import CheckpointManager
|
16 |
|
17 |
CONFIG_PATH = "config.yaml"
|
@@ -21,12 +21,17 @@ SAMPLES_PATH = "./samples/*.jpg"
|
|
21 |
|
22 |
class ImageLoader():
|
23 |
def __init__(self):
|
24 |
-
self.transformer = torchvision.transforms.Compose([
|
25 |
-
|
|
|
26 |
torchvision.transforms.ToTensor()])
|
27 |
def load(self, im_path):
|
28 |
im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
|
29 |
return {"image": im}
|
|
|
|
|
|
|
|
|
30 |
def transform(self, image):
|
31 |
im = torch.FloatTensor(self.transformer(image)).unsqueeze(0)
|
32 |
return {"image": im}
|
|
|
11 |
import wordsegment as ws
|
12 |
|
13 |
from virtex.config import Config
|
14 |
+
from virtex.factories import TokenizerFactory, PretrainingModelFactory, ImageTransformsFactory
|
15 |
from virtex.utils.checkpointing import CheckpointManager
|
16 |
|
17 |
CONFIG_PATH = "config.yaml"
|
|
|
21 |
|
22 |
class ImageLoader():
|
23 |
def __init__(self):
|
24 |
+
self.transformer = torchvision.transforms.Compose([ImageTransformsFactory.create("smallest_resize"),
|
25 |
+
ImageTransformsFactory.create("center_crop"),
|
26 |
+
ImageTransformsFactory.create("normalize"),
|
27 |
torchvision.transforms.ToTensor()])
|
28 |
def load(self, im_path):
|
29 |
im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
|
30 |
return {"image": im}
|
31 |
+
|
32 |
+
def raw_load(self, im_path):
|
33 |
+
im = torch.FloatTensor(Image.open(im_path)).unsqueeze(0)
|
34 |
+
return {"image": im}
|
35 |
def transform(self, image):
|
36 |
im = torch.FloatTensor(self.transformer(image)).unsqueeze(0)
|
37 |
return {"image": im}
|