ariG23498 committed
Commit 1e868bb · 1 Parent(s): 3352348
Files changed (3)
  1. app.py +91 -0
  2. baklava.jpg +0 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,91 @@
+ """This space is taken and modified from https://huggingface.co/spaces/merve/compare_clip_siglip"""
+ import os
+ os.environ["GRADIO_TEMP_DIR"] = "~/.cache/"
+
+ import torch
+ from transformers import AutoModel, AutoProcessor
+ import numpy as np
+ import gradio as gr
+ import spaces
+
+ ################################################################################
+ # Load the models
+ ################################################################################
+ sg1_ckpt = "google/siglip-so400m-patch14-384"
+ siglip1_model = AutoModel.from_pretrained(sg1_ckpt, device_map="auto").eval()
+ siglip1_processor = AutoProcessor.from_pretrained(sg1_ckpt)
+
+ sg2_ckpt = "s0225/siglip2-so400m-patch14-384"
+ siglip2_model = AutoModel.from_pretrained(sg2_ckpt, device_map="auto").eval()
+ siglip2_processor = AutoProcessor.from_pretrained(sg2_ckpt)
+
+ ################################################################################
+ # Utilities
+ ################################################################################
+ def postprocess(output):
+     return {out["label"]: float(out["score"]) for out in output}
+
+
+ def postprocess_siglip(sg1_probs, sg2_probs, labels):
+     sg1_output = {labels[i]: float(np.array(sg1_probs[0])[i]) for i in range(len(labels))}
+     sg2_output = {labels[i]: float(np.array(sg2_probs[0])[i]) for i in range(len(labels))}
+     return sg1_output, sg2_output
+
+ @spaces.GPU
+ def siglip_detector(image, texts):
+     sg1_inputs = siglip1_processor(
+         text=texts, images=image, return_tensors="pt", padding="max_length", max_length=64
+     ).to(siglip1_model.device)
+
+     sg2_inputs = siglip2_processor(
+         text=texts, images=image, return_tensors="pt", padding="max_length", max_length=64
+     ).to(siglip2_model.device)
+
+     with torch.no_grad():
+         sg1_outputs = siglip1_model(**sg1_inputs)
+         sg2_outputs = siglip2_model(**sg2_inputs)
+
+     sg1_logits_per_image = sg1_outputs.logits_per_image
+     sg2_logits_per_image = sg2_outputs.logits_per_image
+
+     sg1_probs = torch.sigmoid(sg1_logits_per_image)
+     sg2_probs = torch.sigmoid(sg2_logits_per_image)
+     return sg1_probs, sg2_probs
+
+
+ def infer(image, candidate_labels):
+     candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
+     sg1_probs, sg2_probs = siglip_detector(image, candidate_labels)
+     return postprocess_siglip(
+         sg1_probs, sg2_probs, labels=candidate_labels
+     )
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Compare SigLIP 1 and SigLIP 2")
+     gr.Markdown(
+         "Compare the performance of SigLIP 1 and SigLIP 2 on zero-shot classification in this Space 👇"
+     )
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(type="pil")
+             text_input = gr.Textbox(label="Input a list of labels (comma separated)")
+             run_button = gr.Button("Run", visible=True)
+
+         with gr.Column():
+             siglip1_output = gr.Label(label="SigLIP 1 Output", num_top_classes=3)
+             siglip2_output = gr.Label(label="SigLIP 2 Output", num_top_classes=3)
+
+     examples = [["./baklava.jpg", "baklava, souffle, tiramisu"]]
+     gr.Examples(
+         examples=examples,
+         inputs=[image_input, text_input],
+         outputs=[siglip1_output, siglip2_output],
+         fn=infer,
+         cache_examples=True,
+     )
+     run_button.click(
+         fn=infer, inputs=[image_input, text_input], outputs=[siglip1_output, siglip2_output]
+     )
+
+ demo.launch()
baklava.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio
+ torch
+ transformers
+ sentencepiece
+ pillow
+ protobuf
+ accelerate
+ spaces
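
For reference, here is a minimal standalone sketch of the zero-shot scoring that `siglip_detector` and `infer` perform in app.py, shown for the SigLIP 1 checkpoint only and using the bundled baklava.jpg example. This is an illustrative snippet, not part of the commit; it simply mirrors the processor call (text padded to 64 tokens, as in app.py) and the sigmoid over `logits_per_image`.

    import torch
    from PIL import Image
    from transformers import AutoModel, AutoProcessor

    ckpt = "google/siglip-so400m-patch14-384"  # same SigLIP 1 checkpoint as app.py
    model = AutoModel.from_pretrained(ckpt).eval()
    processor = AutoProcessor.from_pretrained(ckpt)

    image = Image.open("baklava.jpg")
    labels = ["baklava", "souffle", "tiramisu"]

    # Pad the text to 64 tokens, matching the processor call in app.py.
    inputs = processor(
        text=labels, images=image, return_tensors="pt",
        padding="max_length", max_length=64,
    )
    with torch.no_grad():
        logits = model(**inputs).logits_per_image  # shape: (1, num_labels)

    # Sigmoid rather than softmax: each label is scored independently,
    # so the probabilities need not sum to 1.
    probs = torch.sigmoid(logits)[0]
    print({label: round(float(p), 4) for label, p in zip(labels, probs)})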