vilarin committed on
Commit
3269b1e
1 Parent(s): b769a0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -14
app.py CHANGED
@@ -13,16 +13,23 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None)
13
  MODEL_ID = os.environ.get("MODEL_ID")
14
  MODEL_NAME = MODEL_ID.split("/")[-1]
15
 
16
- TITLE = "<h1><center>VL-Chatbox</center></h1>"
17
 
18
- DESCRIPTION = f'<h3><center>MODEL: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></center></h3>'
19
 
20
  CSS = """
21
  .duplicate-button {
22
- margin: auto !important;
23
- color: white !important;
24
- background: black !important;
25
- border-radius: 100vh !important;
 
 
 
 
 
 
 
26
  }
27
  """
28
 
@@ -37,7 +44,7 @@ model.eval()
37
 
38
 
39
  @spaces.GPU()
40
- def stream_chat(message, history: list, temperature: float, max_length: int):
41
  print(f'message is - {message}')
42
  print(f'history is - {history}')
43
  conversation = []
@@ -46,8 +53,9 @@ def stream_chat(message, history: list, temperature: float, max_length: int):
46
  conversation.append({"role": "user", "image": image, "content": message['text']})
47
  else:
48
  if len(history) == 0:
49
- raise gr.Error("Please upload an image first.")
50
  image = None
 
51
  else:
52
  image = Image.open(history[0][0][0])
53
  for prompt, answer in history:
@@ -65,9 +73,11 @@ def stream_chat(message, history: list, temperature: float, max_length: int):
65
  max_length=max_length,
66
  streamer=streamer,
67
  do_sample=True,
68
- top_k=1,
 
69
  temperature=temperature,
70
- repetition_penalty=1.2,
 
71
  )
72
  gen_kwargs = {**input_ids, **generate_kwargs}
73
 
@@ -91,9 +101,9 @@ chat_input = gr.MultimodalTextbox(
91
 
92
  )
93
  EXAMPLES = [
94
- [{"text": "Describe it in great detailed.", "files": ["./laptop.jpg"]}],
95
- [{"text": "Describe it in great detailed.", "files": ["./hotel.jpg"]}],
96
- [{"text": "Describe it in great detailed.", "files": ["./spacecat.png"]}]
97
  ]
98
 
99
  with gr.Blocks(css=CSS) as demo:
@@ -118,12 +128,37 @@ with gr.Blocks(css=CSS) as demo:
118
  ),
119
  gr.Slider(
120
  minimum=128,
121
- maximum=4096,
122
  step=1,
123
  value=1024,
124
  label="Max Length",
125
  render=False,
126
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  ],
128
  ),
129
  gr.Examples(EXAMPLES,[chat_input])
 
13
  MODEL_ID = os.environ.get("MODEL_ID")
14
  MODEL_NAME = MODEL_ID.split("/")[-1]
15
 
16
+ TITLE = "<h1>VL-Chatbox</h1>"
17
 
18
+ DESCRIPTION = f'<p>A SPACE FOR VLM MODELS</p><br><h3><center>MODEL NOW: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></center></h3>'
19
 
20
  CSS = """
21
  .duplicate-button {
22
+ margin: auto !important;
23
+ color: white !important;
24
+ background: black !important;
25
+ border-radius: 100vh !important;
26
+ }
27
+ h1 {
28
+ text-align: center;
29
+ display: block;
30
+ }
31
+ p {
32
+ text-align: center;
33
  }
34
  """
35
 
 
44
 
45
 
46
  @spaces.GPU()
47
+ def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
48
  print(f'message is - {message}')
49
  print(f'history is - {history}')
50
  conversation = []
 
53
  conversation.append({"role": "user", "image": image, "content": message['text']})
54
  else:
55
  if len(history) == 0:
56
+ #raise gr.Error("Please upload an image first.")
57
  image = None
58
+ conversation.append({"role": "user", "content": message['text']})
59
  else:
60
  image = Image.open(history[0][0][0])
61
  for prompt, answer in history:
 
73
  max_length=max_length,
74
  streamer=streamer,
75
  do_sample=True,
76
+ top_p=top_p,
77
+ top_k=top_k,
78
  temperature=temperature,
79
+ repetition_penalty=penalty,
80
+ eos_token_id=[151329, 151336, 151338],
81
  )
82
  gen_kwargs = {**input_ids, **generate_kwargs}
83
 
 
101
 
102
  )
103
  EXAMPLES = [
104
+ [{"text": "Describe it in detailed", "files": ["./laptop.jpg"]}],
105
+ [{"text": "Where it is?", "files": ["./hotel.jpg"]}],
106
+ [{"text": "Is it real?", "files": ["./spacecat.png"]}]
107
  ]
108
 
109
  with gr.Blocks(css=CSS) as demo:
 
128
  ),
129
  gr.Slider(
130
  minimum=128,
131
+ maximum=8192,
132
  step=1,
133
  value=1024,
134
  label="Max Length",
135
  render=False,
136
  ),
137
+ with gr.Row():
138
+ gr.Slider(
139
+ minimum=0.0,
140
+ maximum=1.0,
141
+ step=0.1,
142
+ value=1.0,
143
+ label="top_p",
144
+ render=False,
145
+ ),
146
+ gr.Slider(
147
+ minimum=1,
148
+ maximum=20,
149
+ step=1,
150
+ value=10,
151
+ label="top_k",
152
+ render=False,
153
+ ),
154
+ gr.Slider(
155
+ minimum=0.0,
156
+ maximum=2.0,
157
+ step=0.1,
158
+ value=1.0,
159
+ label="Repetition penalty",
160
+ render=False,
161
+ ),
162
  ],
163
  ),
164
  gr.Examples(EXAMPLES,[chat_input])