ariG23498 (HF staff) committed
Commit c7c29fb · verified · Parent(s): a7ee8af

Update app.py

Files changed (1): app.py (+36 −22)
app.py CHANGED
@@ -1,61 +1,75 @@
 """This space is taken and modified from https://huggingface.co/spaces/merve/compare_clip_siglip"""
-from transformers import pipeline
+import torch
+from transformers import AutoModel, AutoProcessor
 import gradio as gr

 ################################################################################
 # Load the models
 ################################################################################
 sg1_ckpt = "google/siglip-so400m-patch14-384"
-sg1_pipe = pipeline(task="zero-shot-image-classification", model=sg1_ckpt, device="cpu")
+siglip1_model = AutoModel.from_pretrained(sg1_ckpt, device_map="cpu").eval()
+siglip1_processor = AutoProcessor.from_pretrained(sg1_ckpt)

 sg2_ckpt = "google/siglip2-so400m-patch14-384"
-sg2_pipe = pipeline(task="zero-shot-image-classification", model=sg2_ckpt, device="cpu")
+siglip2_model = AutoModel.from_pretrained(sg2_ckpt, device_map="cpu").eval()
+siglip2_processor = AutoProcessor.from_pretrained(sg2_ckpt)
+

 ################################################################################
-# Run inference
+# Utilities
 ################################################################################
+def postprocess_siglip(sg1_probs, sg2_probs, labels):
+    sg1_output = {labels[i]: sg1_probs[0][i] for i in range(len(labels))}
+    sg2_output = {labels[i]: sg2_probs[0][i] for i in range(len(labels))}
+    return sg1_output, sg2_output
+
+
+def siglip_detector(image, texts):
+    sg1_inputs = siglip1_processor(
+        text=texts, images=image, return_tensors="pt", padding="max_length", max_length=64
+    ).to("cpu")
+    sg2_inputs = siglip2_processor(
+        text=texts, images=image, return_tensors="pt", padding="max_length", max_length=64
+    ).to("cpu")
+    with torch.no_grad():
+        sg1_outputs = siglip1_model(**sg1_inputs)
+        sg2_outputs = siglip2_model(**sg2_inputs)
+    sg1_logits_per_image = sg1_outputs.logits_per_image
+    sg2_logits_per_image = sg2_outputs.logits_per_image
+    sg1_probs = torch.sigmoid(sg1_logits_per_image)
+    sg2_probs = torch.sigmoid(sg2_logits_per_image)
+    return sg1_probs, sg2_probs
+
+
 def infer(image, candidate_labels):
     candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
-
-    sg1_scores = sg1_pipe(image, candidate_labels=candidate_labels)
-    sg2_scores = sg2_pipe(image, candidate_labels=candidate_labels)
-
-    sg1_outputs = {element["label"]: element["score"] for element in sg1_scores}
-    sg2_outputs = {element["label"]: element["score"] for element in sg2_scores}
-
-    return sg1_outputs, sg2_outputs
+    sg1_probs, sg2_probs = siglip_detector(image, candidate_labels)
+    return postprocess_siglip(sg1_probs, sg2_probs, labels=candidate_labels)

-################################################################################
-# Gradio App
-################################################################################
+
 with gr.Blocks() as demo:
     gr.Markdown("# Compare SigLIP 1 and SigLIP 2")
     gr.Markdown(
-        "Compare the performance of SigLIP 1 and SigLIP 2 on zero-shot classification in this Space 👇"
+        "Compare the performance of SigLIP 1 and SigLIP 2 on zero-shot classification in this Space :point_down:"
     )
     with gr.Row():
         with gr.Column():
             image_input = gr.Image(type="pil")
             text_input = gr.Textbox(label="Input a list of labels (comma separated)")
             run_button = gr.Button("Run", visible=True)
-
         with gr.Column():
             siglip1_output = gr.Label(label="SigLIP 1 Output", num_top_classes=3)
             siglip2_output = gr.Label(label="SigLIP 2 Output", num_top_classes=3)
-
     examples = [
         ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon"],
         ["./cat.jpg", "a cat, two cats, three cats"],
         ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
-    ]
+    ]
     gr.Examples(
         examples=examples,
         inputs=[image_input, text_input],
         outputs=[siglip1_output, siglip2_output],
         fn=infer,
     )
-    run_button.click(
-        fn=infer, inputs=[image_input, text_input], outputs=[siglip1_output, siglip2_output]
-    )
-
+    run_button.click(fn=infer, inputs=[image_input, text_input], outputs=[siglip1_output, siglip2_output])
 demo.launch()
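
For context on the new post-processing: SigLIP, unlike CLIP, is trained with a pairwise sigmoid loss, so each candidate label gets an independent probability via `torch.sigmoid(logits_per_image)` rather than a softmax distribution over the labels. Below is a minimal standalone sketch of that path for one model, assuming the same `google/siglip-so400m-patch14-384` checkpoint and the `./cat.jpg` example image shipped with the Space; it is an illustration of the technique, not part of the commit.

# Minimal sketch (not part of the commit): zero-shot classification with SigLIP,
# mirroring what the app's siglip_detector/postprocess_siglip pair does.
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

ckpt = "google/siglip-so400m-patch14-384"  # same checkpoint as the app
model = AutoModel.from_pretrained(ckpt).eval()
processor = AutoProcessor.from_pretrained(ckpt)

image = Image.open("./cat.jpg")  # example image shipped with the Space
labels = ["a cat", "two cats", "three cats"]

inputs = processor(
    text=labels, images=image, return_tensors="pt",
    padding="max_length", max_length=64,  # SigLIP expects max_length padding of the text
)
with torch.no_grad():
    logits = model(**inputs).logits_per_image  # shape: (1, num_labels)
probs = torch.sigmoid(logits)[0]  # independent per-label probabilities; need not sum to 1

for label, p in zip(labels, probs):
    print(f"{label}: {p.item():.3f}")

The removed `zero-shot-image-classification` pipeline should produce essentially the same scores, since recent transformers versions apply a sigmoid internally for SigLIP checkpoints; the commit mainly makes the pre- and post-processing explicit so both models share one code path.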