Dongxu Li committed
Commit 120a3c2
1 Parent(s): 418bb25
Files changed (1)
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
+ from PIL import Image
+
+ import requests
+ import json
+ import gradio as gr
+
+
+ from io import BytesIO
+
+ def encode_image(image):
+     buffered = BytesIO()
+     image.save(buffered, format="JPEG")
+     buffered.seek(0)
+
+     return buffered
+
+
+ def query_api(image, prompt, decoding_method):
+     # inference API endpoint serving the BLIP-2 model
+     url = "http://34.132.142.70:5000/api/generate"
+
+     data = {"prompt": prompt, "use_nucleus_sampling": decoding_method == "Nucleus sampling"}
+
+     image = encode_image(image)
+     files = {"image": image}
+
+     response = requests.post(url, data=data, files=files)
+
+     if response.status_code == 200:
+         return response.json()
+     else:
+         return "Error: " + response.text
+
+
+ def prepend_question(text):
+     text = text.strip().lower()
+
+     return "question: " + text
+
+
+ def prepend_answer(text):
+     text = text.strip().lower()
+
+     return "answer: " + text
+
+
+ def get_prompt_from_history(history):
+     # history alternates user questions (even indices) and model answers (odd indices)
+     prompts = []
+
+     for i in range(len(history)):
+         if i % 2 == 0:
+             prompts.append(prepend_question(history[i]))
+         else:
+             prompts.append(prepend_answer(history[i]))
+
+     return "\n".join(prompts)
+
+
+ def postp_answer(text):
+     # strip a leading "answer: " or "a: " prefix from the model output
+     if text.startswith("answer: "):
+         return text[8:]
+     elif text.startswith("a: "):
+         return text[3:]
+     else:
+         return text
+
+
+ def prep_question(text):
+     # strip a leading "question: " or "q: " prefix and ensure the question ends with "?"
+     if text.startswith("question: "):
+         text = text[10:]
+     elif text.startswith("q: "):
+         text = text[3:]
+
+     if not text.endswith("?"):
+         text += "?"
+
+     return text
+
+
+ def inference(image, text_input, decoding_method, history=[]):
+     text_input = prep_question(text_input)
+     history.append(text_input)
+
+     # prompt = '\n'.join(history)
+     prompt = get_prompt_from_history(history)
+     # print("prompt: " + prompt)
+
+     output = query_api(image, prompt, decoding_method)
+     output = [postp_answer(output[0])]
+     history += output
+
+     chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]  # pair the flat history into (question, answer) tuples
+
+     return chat, history
+
+
+ inputs = [gr.inputs.Image(type='pil'),
+           gr.inputs.Textbox(lines=2, label="Text input"),
+           gr.inputs.Radio(choices=['Nucleus sampling', 'Beam search'], type="value", default="Nucleus sampling", label="Text Decoding Method"),
+           "state",
+           ]
+
+ outputs = ["chatbot", "state"]
+
+ title = "BLIP-2"
+ description = """<p>Gradio demo for BLIP-2, a multimodal chatbot from Salesforce Research. To use it, upload an image and type a question, or click one of the examples to load them. Please visit our <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'>project webpage</a>.</p>
+ <p> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data, including but not restricted to text and images, is collected. </p>"""
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.12086' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a></p>"
+
+ iface = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article)
+ iface.launch(enable_queue=True)
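
Note: the app.py above is only the client. query_api() assumes an HTTP backend at /api/generate that accepts a multipart POST with form fields "prompt" and "use_nucleus_sampling" plus an "image" file, and returns a JSON list whose first element is the generated answer (inference() reads output[0]). That server is not part of this commit; the following is a minimal hypothetical sketch (here using Flask, with a placeholder in place of the actual BLIP-2 generation call) of a handler matching that contract.

    # Hypothetical backend sketch matching the request/response contract used by
    # query_api(); the real server code is not included in this commit.
    from flask import Flask, request, jsonify
    from PIL import Image

    app = Flask(__name__)

    @app.route("/api/generate", methods=["POST"])
    def generate():
        prompt = request.form["prompt"]
        # booleans sent via form data arrive as the strings "True"/"False"
        use_nucleus_sampling = request.form.get("use_nucleus_sampling") == "True"
        image = Image.open(request.files["image"]).convert("RGB")

        # placeholder for the actual BLIP-2 generation step
        answer = "a placeholder answer"

        # the client reads response.json()[0], so return a JSON list
        return jsonify([answer])

    if __name__ == "__main__":
        app.run(host="0.0.0.0", port=5000)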