Spaces:
Runtime error
Runtime error
Samuel Stevens
commited on
Commit
·
6e5adf0
1
Parent(s):
d4005aa
add open-domain classification back
Browse files- .gitattributes +1 -1
- app.py +115 -112
- make_txt_embedding.py +21 -0
- txt_emb_species.json +3 -0
- txt_emb_species.npy +3 -0
.gitattributes
CHANGED
@@ -34,6 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
|
37 |
-
|
38 |
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
39 |
*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
|
37 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
38 |
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
39 |
*.png filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import json
|
2 |
import os
|
3 |
|
@@ -8,15 +10,18 @@ import torch.nn.functional as F
|
|
8 |
from open_clip import create_model, get_tokenizer
|
9 |
from torchvision import transforms
|
10 |
|
11 |
-
import lib
|
12 |
from templates import openai_imagenet_template
|
13 |
|
14 |
hf_token = os.getenv("HF_TOKEN")
|
15 |
|
16 |
model_str = "hf-hub:imageomics/bioclip"
|
17 |
tokenizer_str = "ViT-B-16"
|
18 |
-
|
19 |
-
txt_emb_npy = "
|
|
|
|
|
|
|
|
|
20 |
|
21 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
22 |
|
@@ -33,12 +38,12 @@ preprocess_img = transforms.Compose(
|
|
33 |
|
34 |
ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
zero_shot_examples = [
|
43 |
[
|
44 |
"examples/Ursus-arctos.jpeg",
|
@@ -73,6 +78,10 @@ zero_shot_examples = [
|
|
73 |
]
|
74 |
|
75 |
|
|
|
|
|
|
|
|
|
76 |
@torch.no_grad()
|
77 |
def get_txt_features(classnames, templates):
|
78 |
all_features = []
|
@@ -102,52 +111,38 @@ def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
|
|
102 |
|
103 |
|
104 |
@torch.no_grad()
|
105 |
-
def open_domain_classification(img, rank: int) ->
|
106 |
"""
|
107 |
-
Predicts from the
|
|
|
|
|
108 |
"""
|
109 |
img = preprocess_img(img).to(device)
|
110 |
img_features = model.encode_image(img.unsqueeze(0))
|
111 |
img_features = F.normalize(img_features, dim=-1)
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
name = []
|
116 |
-
for _ in range(rank + 1):
|
117 |
-
children = tuple(zip(*name_lookup.children(name)))
|
118 |
-
if not children:
|
119 |
-
break
|
120 |
-
values, indices = children
|
121 |
-
txt_features = txt_emb[:, indices].to(device)
|
122 |
-
logits = (model.logit_scale.exp() * img_features @ txt_features).view(-1)
|
123 |
-
|
124 |
-
probs = F.softmax(logits, dim=0).to("cpu").tolist()
|
125 |
-
parent = " ".join(name)
|
126 |
-
outputs.append(
|
127 |
-
{f"{parent} {value}": prob for value, prob in zip(values, probs)}
|
128 |
-
)
|
129 |
-
|
130 |
-
top = values[logits.argmax()]
|
131 |
-
name.append(top)
|
132 |
|
133 |
-
|
134 |
-
|
|
|
|
|
|
|
|
|
135 |
|
136 |
-
|
|
|
|
|
|
|
137 |
|
|
|
138 |
|
139 |
-
|
140 |
-
return [
|
141 |
-
gr.Label(
|
142 |
-
num_top_classes=5, label=rank, show_label=True, visible=(6 - i <= choice)
|
143 |
-
)
|
144 |
-
for i, rank in enumerate(reversed(ranks))
|
145 |
-
]
|
146 |
|
147 |
|
148 |
-
def
|
149 |
-
|
150 |
-
return lib.TaxonomicTree.from_dict(json.load(fd))
|
151 |
|
152 |
|
153 |
if __name__ == "__main__":
|
@@ -161,8 +156,9 @@ if __name__ == "__main__":
|
|
161 |
|
162 |
tokenizer = get_tokenizer(tokenizer_str)
|
163 |
|
164 |
-
|
165 |
-
|
|
|
166 |
|
167 |
done = txt_emb.any(axis=0).sum().item()
|
168 |
total = txt_emb.shape[1]
|
@@ -173,69 +169,76 @@ if __name__ == "__main__":
|
|
173 |
with gr.Blocks() as app:
|
174 |
img_input = gr.Image(height=512)
|
175 |
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
# [img_input, *open_domain_outputs], flagging_dir="logs/flagged"
|
204 |
-
# )
|
205 |
-
# open_domain_flag_btn.click(
|
206 |
-
# lambda *args: open_domain_callback.flag(args),
|
207 |
-
# [img_input, *open_domain_outputs],
|
208 |
-
# None,
|
209 |
-
# preprocess=False,
|
210 |
-
# )
|
211 |
-
|
212 |
-
# with gr.Tab("Zero-Shot"):
|
213 |
-
with gr.Row():
|
214 |
-
with gr.Column():
|
215 |
-
classes_txt = gr.Textbox(
|
216 |
-
placeholder="Canis familiaris (dog)\nFelis catus (cat)\n...",
|
217 |
-
lines=3,
|
218 |
-
label="Classes",
|
219 |
-
show_label=True,
|
220 |
-
info="Use taxonomic names where possible; include common names if possible.",
|
221 |
)
|
222 |
-
zero_shot_btn = gr.Button("Submit", variant="primary")
|
223 |
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
outputs=[zero_shot_output],
|
237 |
)
|
238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
zero_shot_callback = gr.HuggingFaceDatasetSaver(
|
240 |
hf_token, "imageomics/bioclip-demo-zero-shot-mistakes", private=True
|
241 |
)
|
@@ -249,15 +252,15 @@ if __name__ == "__main__":
|
|
249 |
preprocess=False,
|
250 |
)
|
251 |
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
|
262 |
zero_shot_btn.click(
|
263 |
fn=zero_shot_classification,
|
|
|
1 |
+
import collections
|
2 |
+
import heapq
|
3 |
import json
|
4 |
import os
|
5 |
|
|
|
10 |
from open_clip import create_model, get_tokenizer
|
11 |
from torchvision import transforms
|
12 |
|
|
|
13 |
from templates import openai_imagenet_template
|
14 |
|
15 |
hf_token = os.getenv("HF_TOKEN")
|
16 |
|
17 |
model_str = "hf-hub:imageomics/bioclip"
|
18 |
tokenizer_str = "ViT-B-16"
|
19 |
+
|
20 |
+
txt_emb_npy = "txt_emb_species.npy"
|
21 |
+
txt_names_json = "txt_emb_species.json"
|
22 |
+
|
23 |
+
min_prob = 1e-9
|
24 |
+
k = 5
|
25 |
|
26 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
27 |
|
|
|
38 |
|
39 |
ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
|
40 |
|
41 |
+
open_domain_examples = [
|
42 |
+
["examples/Ursus-arctos.jpeg", "Species"],
|
43 |
+
["examples/Phoca-vitulina.png", "Species"],
|
44 |
+
["examples/Felis-catus.jpeg", "Genus"],
|
45 |
+
["examples/Sarcoscypha-coccinea.jpeg", "Order"],
|
46 |
+
]
|
47 |
zero_shot_examples = [
|
48 |
[
|
49 |
"examples/Ursus-arctos.jpeg",
|
|
|
78 |
]
|
79 |
|
80 |
|
81 |
+
def indexed(lst, indices):
|
82 |
+
return [lst[i] for i in indices]
|
83 |
+
|
84 |
+
|
85 |
@torch.no_grad()
|
86 |
def get_txt_features(classnames, templates):
|
87 |
all_features = []
|
|
|
111 |
|
112 |
|
113 |
@torch.no_grad()
|
114 |
+
def open_domain_classification(img, rank: int) -> dict[str, float]:
|
115 |
"""
|
116 |
+
Predicts from the entire tree of life.
|
117 |
+
If targeting a higher rank than species, then this function predicts among all
|
118 |
+
species, then sums up species-level probabilities for the given rank.
|
119 |
"""
|
120 |
img = preprocess_img(img).to(device)
|
121 |
img_features = model.encode_image(img.unsqueeze(0))
|
122 |
img_features = F.normalize(img_features, dim=-1)
|
123 |
|
124 |
+
logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
|
125 |
+
probs = F.softmax(logits, dim=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
+
# If predicting species, no need to sum probabilities.
|
128 |
+
if rank + 1 == len(ranks):
|
129 |
+
topk = probs.topk(k)
|
130 |
+
return {
|
131 |
+
" ".join(txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
|
132 |
+
}
|
133 |
|
134 |
+
# Sum up by the rank
|
135 |
+
output = collections.defaultdict(float)
|
136 |
+
for i in torch.nonzero(probs > min_prob).squeeze():
|
137 |
+
output[" ".join(txt_names[i][: rank + 1])] += probs[i]
|
138 |
|
139 |
+
topk_names = heapq.nlargest(k, output, key=output.get)
|
140 |
|
141 |
+
return {name: output[name] for name in topk_names}
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
|
144 |
+
def change_output(choice):
|
145 |
+
return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
|
|
|
146 |
|
147 |
|
148 |
if __name__ == "__main__":
|
|
|
156 |
|
157 |
tokenizer = get_tokenizer(tokenizer_str)
|
158 |
|
159 |
+
txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r")).to(device)
|
160 |
+
with open(txt_names_json) as fd:
|
161 |
+
txt_names = json.load(fd)
|
162 |
|
163 |
done = txt_emb.any(axis=0).sum().item()
|
164 |
total = txt_emb.shape[1]
|
|
|
169 |
with gr.Blocks() as app:
|
170 |
img_input = gr.Image(height=512)
|
171 |
|
172 |
+
with gr.Tab("Open-Ended"):
|
173 |
+
with gr.Row():
|
174 |
+
with gr.Column():
|
175 |
+
rank_dropdown = gr.Dropdown(
|
176 |
+
label="Taxonomic Rank",
|
177 |
+
info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
|
178 |
+
choices=ranks,
|
179 |
+
value="Species",
|
180 |
+
type="index",
|
181 |
+
)
|
182 |
+
open_domain_btn = gr.Button("Submit", variant="primary")
|
183 |
+
with gr.Column():
|
184 |
+
open_domain_output = gr.Label(
|
185 |
+
num_top_classes=k,
|
186 |
+
label="Prediction",
|
187 |
+
show_label=True,
|
188 |
+
value=None,
|
189 |
+
)
|
190 |
+
open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
|
191 |
+
|
192 |
+
with gr.Row():
|
193 |
+
gr.Examples(
|
194 |
+
examples=open_domain_examples,
|
195 |
+
inputs=[img_input, rank_dropdown],
|
196 |
+
cache_examples=True,
|
197 |
+
fn=open_domain_classification,
|
198 |
+
outputs=[open_domain_output],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
)
|
|
|
200 |
|
201 |
+
open_domain_callback = gr.HuggingFaceDatasetSaver(
|
202 |
+
hf_token, "imageomics/bioclip-demo-open-domain-mistakes", private=True
|
203 |
+
)
|
204 |
+
open_domain_callback.setup(
|
205 |
+
[img_input, rank_dropdown, open_domain_output],
|
206 |
+
flagging_dir="logs/flagged",
|
207 |
+
)
|
208 |
+
open_domain_flag_btn.click(
|
209 |
+
lambda *args: open_domain_callback.flag(args),
|
210 |
+
[img_input, rank_dropdown, open_domain_output],
|
211 |
+
None,
|
212 |
+
preprocess=False,
|
|
|
213 |
)
|
214 |
|
215 |
+
with gr.Tab("Zero-Shot"):
|
216 |
+
with gr.Row():
|
217 |
+
with gr.Column():
|
218 |
+
classes_txt = gr.Textbox(
|
219 |
+
placeholder="Canis familiaris (dog)\nFelis catus (cat)\n...",
|
220 |
+
lines=3,
|
221 |
+
label="Classes",
|
222 |
+
show_label=True,
|
223 |
+
info="Use taxonomic names where possible; include common names if possible.",
|
224 |
+
)
|
225 |
+
zero_shot_btn = gr.Button("Submit", variant="primary")
|
226 |
+
|
227 |
+
with gr.Column():
|
228 |
+
zero_shot_output = gr.Label(
|
229 |
+
num_top_classes=k, label="Prediction", show_label=True
|
230 |
+
)
|
231 |
+
zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
|
232 |
+
|
233 |
+
with gr.Row():
|
234 |
+
gr.Examples(
|
235 |
+
examples=zero_shot_examples,
|
236 |
+
inputs=[img_input, classes_txt],
|
237 |
+
cache_examples=True,
|
238 |
+
fn=zero_shot_classification,
|
239 |
+
outputs=[zero_shot_output],
|
240 |
+
)
|
241 |
+
|
242 |
zero_shot_callback = gr.HuggingFaceDatasetSaver(
|
243 |
hf_token, "imageomics/bioclip-demo-zero-shot-mistakes", private=True
|
244 |
)
|
|
|
252 |
preprocess=False,
|
253 |
)
|
254 |
|
255 |
+
rank_dropdown.change(
|
256 |
+
fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
|
257 |
+
)
|
258 |
|
259 |
+
open_domain_btn.click(
|
260 |
+
fn=open_domain_classification,
|
261 |
+
inputs=[img_input, rank_dropdown],
|
262 |
+
outputs=[open_domain_output],
|
263 |
+
)
|
264 |
|
265 |
zero_shot_btn.click(
|
266 |
fn=zero_shot_classification,
|
make_txt_embedding.py
CHANGED
@@ -112,6 +112,26 @@ def convert_txt_features_to_avgs(name_lookup):
|
|
112 |
)
|
113 |
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
def get_name_lookup(catalog_path, cache_path):
|
116 |
if os.path.isfile(cache_path):
|
117 |
with open(cache_path) as fd:
|
@@ -170,3 +190,4 @@ if __name__ == "__main__":
|
|
170 |
tokenizer = get_tokenizer(tokenizer_str)
|
171 |
write_txt_features(name_lookup)
|
172 |
convert_txt_features_to_avgs(name_lookup)
|
|
|
|
112 |
)
|
113 |
|
114 |
|
115 |
+
def convert_txt_features_to_species_only(name_lookup):
|
116 |
+
assert os.path.isfile(args.out_path)
|
117 |
+
|
118 |
+
all_features = np.load(args.out_path)
|
119 |
+
logger.info("Loaded text features from disk.")
|
120 |
+
|
121 |
+
species = [(d, i) for d, i in name_lookup.descendants() if len(d) == 7]
|
122 |
+
species_features = np.zeros((512, len(species)), dtype=np.float32)
|
123 |
+
species_names = [""] * len(species)
|
124 |
+
|
125 |
+
for new_i, (name, old_i) in enumerate(tqdm(species)):
|
126 |
+
species_features[:, new_i] = all_features[:, old_i]
|
127 |
+
species_names[new_i] = name
|
128 |
+
|
129 |
+
out_path, ext = os.path.splitext(args.out_path)
|
130 |
+
np.save(f"{out_path}_species{ext}", species_features)
|
131 |
+
with open(f"{out_path}_species.json", "w") as fd:
|
132 |
+
json.dump(species_names, fd, indent=2)
|
133 |
+
|
134 |
+
|
135 |
def get_name_lookup(catalog_path, cache_path):
|
136 |
if os.path.isfile(cache_path):
|
137 |
with open(cache_path) as fd:
|
|
|
190 |
tokenizer = get_tokenizer(tokenizer_str)
|
191 |
write_txt_features(name_lookup)
|
192 |
convert_txt_features_to_avgs(name_lookup)
|
193 |
+
convert_txt_features_to_species_only(name_lookup)
|
txt_emb_species.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c71babd1b7bc275a1dbb12fd36e6329bcc2487784c0b7be10c2f4d0031d34211
|
3 |
+
size 50445969
|
txt_emb_species.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:91ce02dff2433222e3138b8bf7eefa1dd74b30f4d406c16cd3301f66d65ab4ed
|
3 |
+
size 787435648
|