stevengrove commited on
Commit
8ec41f0
1 Parent(s): 6e030f4

initial commit

Browse files
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import pickle
4
+ import base64
5
+ import requests
6
+ import argparse
7
+ import numpy as np
8
+ import gradio as gr
9
+
10
+ from functools import partial
11
+ from PIL import Image
12
+
13
+ SERVER_URL = os.getenv('SERVER_URL')
14
+
15
+
16
+ def get_images(state):
17
+ history = ''
18
+ for i in range(len(state)):
19
+ for j in range(len(state[i])):
20
+ history += state[i][j] + '\n'
21
+ for image_path in re.findall('image/[0-9,a-z]+\.png', history):
22
+ if os.path.exists(image_path):
23
+ continue
24
+ data = {'method': 'get_image', 'args': [image_path], 'kwargs': {}}
25
+ data = base64.b64encode(pickle.dumps(data)).decode('utf-8')
26
+ response = requests.post(SERVER_URL, json=data)
27
+ image = pickle.loads(base64.b64decode(response.json().encode('utf-8')))
28
+ image.save(image_path)
29
+
30
+
31
+ def bot_request(method, *args, **kwargs):
32
+ data = {'method': method, 'args': args, 'kwargs': kwargs}
33
+ data = base64.b64encode(pickle.dumps(data)).decode('utf-8')
34
+ response = requests.post(SERVER_URL, json=data)
35
+ response = pickle.loads(base64.b64decode(response.json().encode('utf-8')))
36
+ if response is not None:
37
+ state = response[0]
38
+ get_images(state)
39
+ return response
40
+
41
+
42
+ if __name__ == '__main__':
43
+ parser = argparse.ArgumentParser()
44
+ parser.add_argument('--temperature', type=float, default=0.0, help='temperature for the llm model')
45
+ parser.add_argument('--max_new_tokens', type=int, default=128, help='max number of new tokens to generate')
46
+ parser.add_argument('--top_p', type=float, default=1.0, help='top_p for the llm model')
47
+ parser.add_argument('--top_k', type=int, default=40, help='top_k for the llm model')
48
+ parser.add_argument('--num_beams', type=int, default=4, help='num_beams for the llm model')
49
+ parser.add_argument('--keep_last_n_paragraphs', type=int, default=1, help='keep last n paragraphs in the memory')
50
+ args = parser.parse_args()
51
+
52
+ examples = [
53
+ ['images/example-1.jpg', 'What is unusual about this image?'],
54
+ ['images/example-2.jpg', 'Make the image look like a cartoon.'],
55
+ ['images/example-3.jpg', 'Segment the tie in the image.'],
56
+ ['images/example-4.jpg', 'Generate a man watching a sea based on the pose of the woman.'],
57
+ ['images/example-5.jpg', 'Replace the dog with a cat.'],
58
+ ]
59
+
60
+ if not os.path.exists('image'):
61
+ os.makedirs('image')
62
+
63
+ with gr.Blocks() as demo:
64
+ with gr.Row():
65
+ with gr.Column(scale=0.3):
66
+ with gr.Row():
67
+ image = gr.Image(type="pil", label="input image")
68
+ with gr.Row():
69
+ txt = gr.Textbox(lines=7, show_label=False, elem_id="textbox",
70
+ placeholder="Enter text and press submit, or upload an image").style(container=False)
71
+ with gr.Row():
72
+ submit = gr.Button("Submit")
73
+ with gr.Row():
74
+ clear = gr.Button("Clear")
75
+ with gr.Row():
76
+ keep_last_n_paragraphs = gr.Slider(
77
+ minimum=0,
78
+ maximum=3,
79
+ value=args.keep_last_n_paragraphs,
80
+ step=1,
81
+ interactive=True,
82
+ label="Remember Last N Paragraphs")
83
+ max_new_token = gr.Slider(
84
+ minimum=128,
85
+ maximum=1024,
86
+ value=args.max_new_tokens,
87
+ step=64,
88
+ interactive=True,
89
+ label="Max New Tokens")
90
+ temperature = gr.Slider(
91
+ minimum=0.0,
92
+ maximum=1.0,
93
+ value=args.temperature,
94
+ step=0.1,
95
+ interactive=True,
96
+ label="Temperature")
97
+ top_p = gr.Slider(
98
+ minimum=0.0,
99
+ maximum=1.0,
100
+ value=args.top_p,
101
+ step=0.1,
102
+ interactive=True,
103
+ label="Top P")
104
+ with gr.Column(scale=0.7):
105
+ chatbot = gr.Chatbot(elem_id="chatbot", label="🦙 GPT4Tools").style(height=690)
106
+ state = gr.State([])
107
+
108
+ txt.submit(partial(bot_request, 'run_text'), [txt, state], [chatbot, state])
109
+ txt.submit(lambda: "", None, txt)
110
+ image.upload(lambda: "", None, txt)
111
+ submit.click(partial(bot_request, 'run_image'), [image, state, txt], [chatbot, state, txt]).then(
112
+ partial(bot_request, 'run_text'), [txt, state, temperature, top_p, max_new_token, keep_last_n_paragraphs], [chatbot, state, txt]).then(
113
+ lambda: None, None, image)
114
+ clear.click(partial(bot_request, 'clear'))
115
+ clear.click(lambda: [], None, chatbot)
116
+ clear.click(lambda: [], None, state)
117
+ with gr.Row():
118
+ gr.Examples(
119
+ examples=examples,
120
+ inputs=[image, txt],
121
+ )
122
+ demo.launch()
images/example-1.jpg ADDED
images/example-2.jpg ADDED
images/example-3.jpg ADDED
images/example-4.jpg ADDED
images/example-5.jpg ADDED
images/overview.png ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ requests
2
+ base64
3
+ numpy
4
+ gradio
5
+ pillow
6
+ functools