shaktibiplab committed · verified
Commit ce0408c · 1 Parent(s): df6a7f3

Upload app.py

Files changed (1)
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
+ import torch
+ from transformers import pipeline, AutoTokenizer
+ import gradio as gr
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+ tokenizer.clean_up_tokenization_spaces = False  # Explicitly set the parameter if needed
+
+ # Load CLIP model for zero-shot classification
+ clip_checkpoint = "openai/clip-vit-base-patch16"
+ clip_detector = pipeline(model=clip_checkpoint, task="zero-shot-image-classification")
+
+ # Postprocess the output from CLIP
+ def postprocess(output):
+     return {out["label"]: float(out["score"]) for out in output}
+
+ # Inference function for CLIP
+ def infer(image, candidate_labels):
+     candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
+     clip_out = clip_detector(image, candidate_labels=candidate_labels)
+     return postprocess(clip_out)
+
+ # Gradio interface
+ with gr.Blocks() as app:
+     gr.Markdown("# Custom Classification")
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(type="pil")
+             text_input = gr.Textbox(label="Input a list of labels")
+             run_button = gr.Button("Run")
+
+         with gr.Column():
+             clip_output = gr.Label(label="Output", num_top_classes=3)
+
+     examples = [["image_8.webp", "girl, boy, lgbtq"]]
+     gr.Examples(
+         examples=examples,
+         inputs=[image_input, text_input],
+         outputs=[clip_output],
+         fn=infer,
+         cache_examples=True
+     )
+
+     run_button.click(fn=infer,
+                      inputs=[image_input, text_input],
+                      outputs=[clip_output])
+
+ app.launch()
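
Note on the inference path: the app wraps the transformers zero-shot-image-classification pipeline, so the same step can be exercised outside Gradio. A minimal sketch, assuming a local image file (the path "example.jpg" and the label list below are placeholders, not part of this commit):

from PIL import Image
from transformers import pipeline

# Same checkpoint and task as in app.py
clip_detector = pipeline(model="openai/clip-vit-base-patch16",
                         task="zero-shot-image-classification")

# Hypothetical inputs; any PIL-readable image works
image = Image.open("example.jpg")
labels = ["girl", "boy", "lgbtq"]

# Each result entry is a dict with "label" and "score" keys,
# which is what postprocess() in app.py turns into a label->score mapping
results = clip_detector(image, candidate_labels=labels)
print({r["label"]: float(r["score"]) for r in results})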