user-agent commited on
Commit
a4dc223
·
verified ·
1 Parent(s): 6e0ffb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -13
app.py CHANGED
@@ -1,4 +1,6 @@
1
  from turtle import title
 
 
2
  import gradio as gr
3
  from transformers import pipeline
4
  import numpy as np
@@ -9,21 +11,43 @@ pipe = pipeline("zero-shot-image-classification", model="patrickjohncyh/fashion-
9
  images="dog.jpg"
10
 
11
  @spaces.GPU
12
- def shot(image, labels_text):
13
- PIL_image = Image.fromarray(np.uint8(image)).convert('RGB')
 
 
 
 
 
 
 
 
 
14
  labels = labels_text.split(",")
 
 
15
  res = pipe(images=PIL_image,
16
- candidate_labels=labels,
17
- hypothesis_template= "This is a photo of a {}")
18
- return {dic["label"]: dic["score"] for dic in res}
19
 
20
- iface = gr.Interface(shot,
21
- ["image", "text"],
22
- "label",
23
- examples=[["dog.jpg", "dog,cat,bird"],
24
- ["germany.jpg", "germany,belgium,colombia"],
25
- ["colombia.jpg", "germany,belgium,colombia"]],
26
- description="Add a picture and a list of labels separated by commas",
27
- title="Zero-shot Image Classification")
 
 
 
 
 
 
 
 
 
 
 
28
 
 
29
  iface.launch()
 
1
  from turtle import title
2
+ import requests
3
+ from io import BytesIO
4
  import gradio as gr
5
  from transformers import pipeline
6
  import numpy as np
 
11
  images="dog.jpg"
12
 
13
  @spaces.GPU
14
+ def shot(input, labels_text):
15
+ # Check if the input is a URL or an uploaded image
16
+ if isinstance(input, str) and (input.startswith("http://") or input.startswith("https://")):
17
+ # Input is a URL
18
+ response = requests.get(input)
19
+ PIL_image = Image.open(BytesIO(response.content)).convert('RGB')
20
+ else:
21
+ # Input is an uploaded image
22
+ PIL_image = Image.fromarray(np.uint8(input)).convert('RGB')
23
+
24
+ # Split labels into a list
25
  labels = labels_text.split(",")
26
+
27
+ # Perform the zero-shot image classification
28
  res = pipe(images=PIL_image,
29
+ candidate_labels=labels,
30
+ hypothesis_template="This is a photo of a {}")
 
31
 
32
+ # Return the classification results as a dictionary
33
+ return {dic["label"]: dic["score"] for dic in res}
34
+
35
+ # Define the Gradio interface
36
+ iface = gr.Interface(
37
+ fn=shot,
38
+ inputs=[
39
+ gr.inputs.Textbox(label="Image URL (starting with http/https) or Upload Image"),
40
+ "text"
41
+ ],
42
+ outputs="label",
43
+ examples=[
44
+ ["https://example.com/dog.jpg", "dog,cat,bird"],
45
+ ["https://example.com/germany.jpg", "germany,belgium,colombia"],
46
+ ["https://example.com/colombia.jpg", "germany,belgium,colombia"]
47
+ ],
48
+ description="Add an image URL (starting with http/https) or upload a picture, and provide a list of labels separated by commas.",
49
+ title="Zero-shot Image Classification"
50
+ )
51
 
52
+ # Launch the interface
53
  iface.launch()