Spaces:
Runtime error
Runtime error
I think this works
Browse files
- app.py +5 -2
- model.py +2 -0
- virtex/virtex/data/__init__.py +0 -1
- virtex/virtex/data/datasets/redcaps.py +0 -129
- virtex/virtex/factories.py +0 -2
app.py
CHANGED
@@ -42,11 +42,14 @@ if uploaded_image is None and submitted:
|
|
42 |
else:
|
43 |
image_file = sample_image if sample_image is not None else random_image
|
44 |
|
45 |
-
image = uploaded_image if uploaded_image is not None else Image.open()
|
46 |
|
47 |
image_dict = imageLoader.transform(image)
|
48 |
|
49 |
-
|
|
|
|
|
|
|
50 |
|
51 |
with st.spinner("Generating Caption"):
|
52 |
subreddit, caption = virtexModel.predict(image_dict)
|
|
|
42 |
else:
|
43 |
image_file = sample_image if sample_image is not None else random_image
|
44 |
|
45 |
+
image = uploaded_image if uploaded_image is not None else Image.open(image_file)
|
46 |
|
47 |
image_dict = imageLoader.transform(image)
|
48 |
|
49 |
+
image = imageLoader.to_image(image_dict["image"].squeeze(0))
|
50 |
+
|
51 |
+
show = st.image(image)
|
52 |
+
show.image(image, "Your Image")
|
53 |
|
54 |
with st.spinner("Generating Caption"):
|
55 |
subreddit, caption = virtexModel.predict(image_dict)
|
model.py
CHANGED
@@ -30,6 +30,8 @@ class ImageLoader():
|
|
30 |
def transform(self, image):
|
31 |
im = torch.FloatTensor(self.transformer(image)).unsqueeze(0)
|
32 |
return {"image": im}
|
|
|
|
|
33 |
|
34 |
class VirTexModel():
|
35 |
def __init__(self):
|
|
|
30 |
def transform(self, image):
|
31 |
im = torch.FloatTensor(self.transformer(image)).unsqueeze(0)
|
32 |
return {"image": im}
|
33 |
+
def to_image(self, tensor):
|
34 |
+
return torchvision.transforms.ToPILImage()(tensor)
|
35 |
|
36 |
class VirTexModel():
|
37 |
def __init__(self):
|
virtex/virtex/data/__init__.py
CHANGED
@@ -10,7 +10,6 @@ from .datasets.downstream import (
|
|
10 |
VOC07ClassificationDataset,
|
11 |
ImageDirectoryDataset,
|
12 |
)
|
13 |
-
from .datasets.redcaps import TarfileDataset
|
14 |
|
15 |
|
16 |
__all__ = [
|
|
|
10 |
VOC07ClassificationDataset,
|
11 |
ImageDirectoryDataset,
|
12 |
)
|
|
|
13 |
|
14 |
|
15 |
__all__ = [
|
virtex/virtex/data/datasets/redcaps.py
CHANGED
@@ -1,129 +0,0 @@
|
|
1 |
-
import glob
|
2 |
-
import os
|
3 |
-
import random
|
4 |
-
from typing import Callable
|
5 |
-
|
6 |
-
import numpy as np
|
7 |
-
import torch
|
8 |
-
from torch.utils.data import IterableDataset
|
9 |
-
import webdataset as wds
|
10 |
-
import wordsegment as ws
|
11 |
-
|
12 |
-
from virtex.data.tokenizers import SentencePieceBPETokenizer
|
13 |
-
from virtex.data import transforms as T
|
14 |
-
import virtex.utils.distributed as dist
|
15 |
-
|
16 |
-
ws.load()
|
17 |
-
|
18 |
-
|
19 |
-
class TarfileDataset(IterableDataset):
|
20 |
-
def __init__(
|
21 |
-
self,
|
22 |
-
data_root: str,
|
23 |
-
batch_size: int,
|
24 |
-
tokenizer: SentencePieceBPETokenizer,
|
25 |
-
image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
|
26 |
-
shuffle_buffer_size: int = 3000, # Set -1 to turn off shuffle.
|
27 |
-
max_caption_length: int = 50,
|
28 |
-
):
|
29 |
-
super().__init__()
|
30 |
-
|
31 |
-
self.tokenizer = tokenizer
|
32 |
-
self.image_transform = image_transform
|
33 |
-
self.max_caption_length = max_caption_length
|
34 |
-
|
35 |
-
self.padding_idx = tokenizer.token_to_id("<unk>")
|
36 |
-
self.sos_idx = tokenizer.token_to_id("[SOS]")
|
37 |
-
self.eos_idx = tokenizer.token_to_id("[EOS]")
|
38 |
-
self.sep_idx = tokenizer.token_to_id("[SEP]")
|
39 |
-
|
40 |
-
# Glob expand all paths in data root.
|
41 |
-
all_data_paths = []
|
42 |
-
for dr in data_root.split(" "):
|
43 |
-
all_data_paths.extend(glob.glob(dr))
|
44 |
-
|
45 |
-
# Deterministic shuffle across GPU process.
|
46 |
-
all_data_paths = sorted(all_data_paths)
|
47 |
-
random.Random(0).shuffle(all_data_paths)
|
48 |
-
|
49 |
-
# Shard the data paths as per gpu process.
|
50 |
-
all_data_paths = all_data_paths[dist.get_rank()::dist.get_world_size()]
|
51 |
-
|
52 |
-
self._dset = (
|
53 |
-
wds.WebDataset(all_data_paths)
|
54 |
-
.shuffle(shuffle_buffer_size, initial=shuffle_buffer_size)
|
55 |
-
.decode("rgb8", handler=wds.warn_and_continue)
|
56 |
-
.map(self._preprocess)
|
57 |
-
.batched(batch_size)
|
58 |
-
)
|
59 |
-
# Perform word-segmentation of all subreddit names (that's how the
|
60 |
-
# tokenizer was prepared). Subreddit names can be obtained from
|
61 |
-
# TAR file names: `{subreddit}_{year}_{index}.tar`.
|
62 |
-
if "redcaps" in data_root:
|
63 |
-
self.subreddit_segs = {
|
64 |
-
sub: " ".join(ws.segment(ws.clean(sub))) for sub in
|
65 |
-
set([os.path.basename(p).split("_")[0] for p in all_data_paths])
|
66 |
-
}
|
67 |
-
|
68 |
-
def _preprocess(self, annotation):
|
69 |
-
image, caption = annotation["jpg"], annotation["json"]["caption"]
|
70 |
-
|
71 |
-
# Transform image-caption pair and convert image from HWC to CHW format.
|
72 |
-
# Pass in caption to image_transform due to paired horizontal flip.
|
73 |
-
# Caption won't be tokenized/processed here.
|
74 |
-
image_caption = self.image_transform(image=image, caption=caption)
|
75 |
-
image, caption = image_caption["image"], image_caption["caption"]
|
76 |
-
image = np.transpose(image, (2, 0, 1))
|
77 |
-
|
78 |
-
# Tokenize caption.
|
79 |
-
_caption_tokens = self.tokenizer.encode(caption)
|
80 |
-
|
81 |
-
# Get subreddit name if it exists, and tokenize it. Only for RedCaps.
|
82 |
-
if "subreddit" in annotation["json"]:
|
83 |
-
subreddit = annotation["json"]["subreddit"].lower()
|
84 |
-
subreddit = self.subreddit_segs[subreddit]
|
85 |
-
|
86 |
-
# Add special [SEP] token after subreddit.
|
87 |
-
_subreddit_tokens = self.tokenizer.encode(subreddit) + [self.sep_idx]
|
88 |
-
else:
|
89 |
-
_subreddit_tokens = []
|
90 |
-
|
91 |
-
# Create forward and backward caption with subreddit token at the start.
|
92 |
-
caption_tokens = (
|
93 |
-
[self.sos_idx] + _subreddit_tokens + _caption_tokens + [self.eos_idx]
|
94 |
-
)[: self.max_caption_length]
|
95 |
-
|
96 |
-
noitpac_tokens = (
|
97 |
-
[self.eos_idx] + _subreddit_tokens + _caption_tokens[::-1] + [self.sos_idx]
|
98 |
-
)[: self.max_caption_length]
|
99 |
-
|
100 |
-
return image, caption_tokens, noitpac_tokens, len(caption_tokens)
|
101 |
-
|
102 |
-
def __len__(self):
|
103 |
-
raise NotImplementedError
|
104 |
-
|
105 |
-
def __iter__(self):
|
106 |
-
|
107 |
-
for batch in iter(self._dset):
|
108 |
-
# Collate the batch properly here. `image` and `caption_lengths`
|
109 |
-
# are already tensors.
|
110 |
-
image, caption_tokens, noitpac_tokens, caption_lengths = batch
|
111 |
-
|
112 |
-
# Pad `caption_tokens` and `masked_labels` up to this length.
|
113 |
-
caption_tokens = torch.nn.utils.rnn.pad_sequence(
|
114 |
-
[torch.tensor(c, dtype=torch.long) for c in caption_tokens],
|
115 |
-
batch_first=True,
|
116 |
-
padding_value=self.padding_idx,
|
117 |
-
)
|
118 |
-
noitpac_tokens = torch.nn.utils.rnn.pad_sequence(
|
119 |
-
[torch.tensor(c, dtype=torch.long) for c in noitpac_tokens],
|
120 |
-
batch_first=True,
|
121 |
-
padding_value=self.padding_idx,
|
122 |
-
)
|
123 |
-
caption_lengths = torch.tensor(caption_lengths, dtype=torch.long)
|
124 |
-
yield {
|
125 |
-
"image": torch.tensor(image, dtype=torch.float),
|
126 |
-
"caption_tokens": caption_tokens,
|
127 |
-
"noitpac_tokens": noitpac_tokens,
|
128 |
-
"caption_lengths": caption_lengths,
|
129 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
virtex/virtex/factories.py
CHANGED
@@ -194,8 +194,6 @@ class PretrainingDatasetFactory(Factory):
|
|
194 |
"masked_lm": vdata.MaskedLmDataset,
|
195 |
"token_classification": vdata.TokenClassificationDataset,
|
196 |
"multilabel_classification": vdata.MultiLabelClassificationDataset,
|
197 |
-
"virtex_web": vdata.TarfileDataset,
|
198 |
-
"miniclip_web": vdata.TarfileDataset,
|
199 |
}
|
200 |
|
201 |
@classmethod
|
|
|
194 |
"masked_lm": vdata.MaskedLmDataset,
|
195 |
"token_classification": vdata.TokenClassificationDataset,
|
196 |
"multilabel_classification": vdata.MultiLabelClassificationDataset,
|
|
|
|
|
197 |
}
|
198 |
|
199 |
@classmethod
|