not-lain committed
Commit fe9fde2 • Parent: 31d125f

🌘w🌖

Files changed (1): app.py (+17, -10)
app.py CHANGED
@@ -4,11 +4,12 @@ import torch
 from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
 from datasets import load_dataset
 
+
 dataset = load_dataset("not-lain/embedded-pokemon", split="train")
 dataset = dataset.add_faiss_index("embeddings")
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
 processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
 model = AutoModelForZeroShotImageClassification.from_pretrained(
     "openai/clip-vit-large-patch14", device_map=device
@@ -31,17 +32,23 @@ def search(query: str, k: int = 4):
         img_emb,  # compare our new embedded query with the dataset embeddings
         k=k,  # get only top k results
     )
-    images = retrieved_examples["image"]
-    # labels = {}
-    # for i in range(k):
-    #     labels[retrieved_examples["text"][k-i]] = scores[k-i]
 
-    return images  # , labels
+    # return as image, caption pairs
+    out = []
+    for i in range(k):
+        out.append([retrieved_examples["image"][i], retrieved_examples["text"][i]])
+
+    return out
+
 
-demo = gr.Interface(search, inputs="image", outputs=["gallery"
-                    #, "label"
-                    ],
-                    examples=["./charmander.jpg"],
+demo = gr.Interface(
+    search,
+    inputs="image",
+    outputs=[
+        "gallery"
+        # , "label"
+    ],
+    examples=["./charmander.jpg"],
 )
 
 demo.launch(debug=True)
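
For orientation: the diff only shows the tail of search(); the query-embedding step between the function header and the index lookup lies outside the hunks. Below is a minimal sketch of how the full app.py plausibly reads after this commit. The processor call, get_image_features(), the torch.no_grad() guard, and querying the FAISS index by the name "embeddings" are inferences from the surrounding lines, not shown in the diff.

import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from datasets import load_dataset

# Pokemon dataset ships with precomputed CLIP embeddings; build a FAISS
# index over the "embeddings" column for nearest-neighbour lookups.
dataset = load_dataset("not-lain/embedded-pokemon", split="train")
dataset = dataset.add_faiss_index("embeddings")

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
model = AutoModelForZeroShotImageClassification.from_pretrained(
    "openai/clip-vit-large-patch14", device_map=device
)


def search(query: str, k: int = 4):
    # Despite the str annotation (kept from the hunk header above), Gradio's
    # "image" input passes the uploaded picture as a numpy array.
    with torch.no_grad():  # assumed: embed the query image with CLIP
        inputs = processor(images=query, return_tensors="pt").to(device)
        img_emb = model.get_image_features(**inputs)[0].cpu().numpy()
    scores, retrieved_examples = dataset.get_nearest_examples(
        "embeddings",  # assumed: the index name added above
        img_emb,  # compare our new embedded query with the dataset embeddings
        k=k,  # get only top k results
    )
    # return as image, caption pairs
    out = []
    for i in range(k):
        out.append([retrieved_examples["image"][i], retrieved_examples["text"][i]])
    return out


demo = gr.Interface(
    search,
    inputs="image",
    outputs=["gallery"],
    examples=["./charmander.jpg"],
)

demo.launch(debug=True)

On the return-type change itself: a gr.Gallery output accepts either bare images or [image, caption] pairs, and returning pairs is what lets each retrieved Pokemon render with its caption under the thumbnail. That presumably also explains why the commented-out ", 'label'" output (an earlier attempt to surface the FAISS scores through a separate label component) stays disabled: the text now travels with each gallery item instead.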