not-lain commited on
Commit
9a53af3
β€’
1 Parent(s): 6808732

🌘wπŸŒ–

Browse files
Files changed (1) hide show
  1. app.py +17 -10
app.py CHANGED
@@ -16,22 +16,29 @@ model = AutoModelForZeroShotImageClassification.from_pretrained(
16
 
17
 
18
  @spaces.GPU
19
- def search(query: str, k: int = 4 ):
20
  """a function that embeds a new image and returns the most probable results"""
21
 
22
- pixel_values = processor(images = query, return_tensors="pt")['pixel_values'] # embed new image
 
 
23
  pixel_values = pixel_values.to(device)
24
- img_emb = model.get_image_features(pixel_values)[0] # because 1 element
25
- img_emb = img_emb.cpu().detach().numpy() # because datasets only works with numpy
26
 
27
- scores, retrieved_examples = dataset.get_nearest_examples( # retrieve results
28
- "embeddings", img_emb, # compare our new embedded query with the dataset embeddings
29
- k=k # get only top k results
 
30
  )
 
 
 
 
31
 
32
- return retrieved_examples["image"]
33
 
34
 
35
- demo = gr.Interface(search, inputs="image", outputs=["gallery"])
36
 
37
- demo.launch(debug=True)
 
16
 
17
 
18
  @spaces.GPU
19
+ def search(query: str, k: int = 4):
20
  """a function that embeds a new image and returns the most probable results"""
21
 
22
+ pixel_values = processor(images=query, return_tensors="pt")[
23
+ "pixel_values"
24
+ ] # embed new image
25
  pixel_values = pixel_values.to(device)
26
+ img_emb = model.get_image_features(pixel_values)[0] # because 1 element
27
+ img_emb = img_emb.cpu().detach().numpy() # because datasets only works with numpy
28
 
29
+ scores, retrieved_examples = dataset.get_nearest_examples( # retrieve results
30
+ "embeddings",
31
+ img_emb, # compare our new embedded query with the dataset embeddings
32
+ k=k, # get only top k results
33
  )
34
+ images = retrieved_examples["image"]
35
+ labels = {}
36
+ for i in range(k):
37
+ labels[retrieved_examples["text"][i]] = scores[i]
38
 
39
+ return images, labels
40
 
41
 
42
+ demo = gr.Interface(search, inputs="image", outputs=["gallery", "label"])
43
 
44
+ demo.launch(debug=True)