merve HF staff commited on
Commit
08bcb47
1 Parent(s): db8a6e8

Added back parameters

Browse files
Files changed (1) hide show
  1. app.py +76 -16
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import string
3
 
@@ -17,11 +19,9 @@ quantization_config = BitsAndBytesConfig(
17
  pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})
18
 
19
 
20
-
21
  def extract_response_pairs(text):
22
  pattern = re.compile(r'(USER:.*?)ASSISTANT:(.*?)(?:$|USER:)', re.DOTALL)
23
  matches = pattern.findall(text)
24
- print(matches)
25
 
26
  pairs = [(user.strip(), assistant.strip()) for user, assistant in matches]
27
 
@@ -35,19 +35,26 @@ def postprocess_output(output: str) -> str:
35
 
36
 
37
 
38
- def chat(image, text, max_length, history_chat):
 
 
39
 
40
- prompt = " ".join(history_chat) + f"USER: <image>\n{text}\nASSISTANT:"
 
41
 
42
  outputs = pipe(image, prompt=prompt,
43
- generate_kwargs={
44
- "max_length":max_length})
 
 
 
 
 
45
 
46
- #output = postprocess_output(outputs[0]["generated_text"])
47
- history_chat.append(outputs[0]["generated_text"])
48
 
49
  chat_val = extract_response_pairs(" ".join(history_chat))
50
-
51
  return chat_val, history_chat
52
 
53
 
@@ -60,33 +67,81 @@ css = """
60
  """
61
  with gr.Blocks(css="style.css") as demo:
62
  gr.Markdown(DESCRIPTION)
63
- gr.Markdown("LLaVA is now available in transformers with 4-bit quantization ⚡️")
 
 
64
  chatbot = gr.Chatbot(label="Chat", show_label=False)
65
- gr.Markdown("Input image and text to start chatting 👇 ")
66
  with gr.Row():
 
67
  image = gr.Image(type="pil")
68
- text_input = gr.Text(label="Chat Input", max_lines=1)
 
 
69
 
70
  history_chat = gr.State(value=[])
71
  with gr.Row():
72
  clear_chat_button = gr.Button("Clear")
73
  chat_button = gr.Button("Submit", variant="primary")
74
  with gr.Accordion(label="Advanced settings", open=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  max_length = gr.Slider(
76
  label="Max Length",
77
  minimum=1,
78
- maximum=200,
79
  step=1,
80
- value=150,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  )
82
-
83
  chat_output = [
84
  chatbot,
85
  history_chat
86
  ]
87
  chat_button.click(fn=chat, inputs=[image,
88
  text_input,
 
 
 
89
  max_length,
 
 
90
  history_chat],
91
  outputs=chat_output,
92
  api_name="Chat",
@@ -95,7 +150,12 @@ with gr.Blocks(css="style.css") as demo:
95
  chat_inputs = [
96
  image,
97
  text_input,
 
 
 
98
  max_length,
 
 
99
  history_chat
100
  ]
101
  text_input.submit(
@@ -130,4 +190,4 @@ with gr.Blocks(css="style.css") as demo:
130
 
131
 
132
  if __name__ == "__main__":
133
- demo.queue(max_size=10).launch(debug=True)
 
1
+ from __future__ import annotations
2
+
3
  import os
4
  import string
5
 
 
19
  pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})
20
 
21
 
 
22
  def extract_response_pairs(text):
23
  pattern = re.compile(r'(USER:.*?)ASSISTANT:(.*?)(?:$|USER:)', re.DOTALL)
24
  matches = pattern.findall(text)
 
25
 
26
  pairs = [(user.strip(), assistant.strip()) for user, assistant in matches]
27
 
 
35
 
36
 
37
 
38
+ def chat(image, text, temperature, length_penalty,
39
+ repetition_penalty, max_length, min_length, num_beams, top_p,
40
+ history_chat):
41
 
42
+ prompt = " ".join(history_chat)
43
+ prompt = f"USER: <image>\n{text}\nASSISTANT:"
44
 
45
  outputs = pipe(image, prompt=prompt,
46
+ generate_kwargs={"temperature":temperature,
47
+ "length_penalty":length_penalty,
48
+ "repetition_penalty":repetition_penalty,
49
+ "max_length":max_length,
50
+ "min_length":min_length,
51
+ "num_beams":num_beams,
52
+ "top_p":top_p})
53
 
54
+ output = postprocess_output(outputs[0]["generated_text"])
55
+ history_chat.append(output)
56
 
57
  chat_val = extract_response_pairs(" ".join(history_chat))
 
58
  return chat_val, history_chat
59
 
60
 
 
67
  """
68
  with gr.Blocks(css="style.css") as demo:
69
  gr.Markdown(DESCRIPTION)
70
+ gr.Markdown("**LLaVA, one of the greatest multimodal chat models is now available in transformers with 4-bit quantization! ⚡️ **")
71
+ gr.Markdown("**Try it in this demo 🤗 **")
72
+
73
  chatbot = gr.Chatbot(label="Chat", show_label=False)
74
+ gr.Markdown("Input image and text and start chatting 👇")
75
  with gr.Row():
76
+
77
  image = gr.Image(type="pil")
78
+ text_input = gr.Text(label="Chat Input", show_label=False, max_lines=3, container=False)
79
+
80
+
81
 
82
  history_chat = gr.State(value=[])
83
  with gr.Row():
84
  clear_chat_button = gr.Button("Clear")
85
  chat_button = gr.Button("Submit", variant="primary")
86
  with gr.Accordion(label="Advanced settings", open=False):
87
+ temperature = gr.Slider(
88
+ label="Temperature",
89
+ info="Used with nucleus sampling.",
90
+ minimum=0.5,
91
+ maximum=1.0,
92
+ step=0.1,
93
+ value=1.0,
94
+ )
95
+ length_penalty = gr.Slider(
96
+ label="Length Penalty",
97
+ info="Set to larger for longer sequence, used with beam search.",
98
+ minimum=-1.0,
99
+ maximum=2.0,
100
+ step=0.2,
101
+ value=1.0,
102
+ )
103
+ repetition_penalty = gr.Slider(
104
+ label="Repetition Penalty",
105
+ info="Larger value prevents repetition.",
106
+ minimum=1.0,
107
+ maximum=5.0,
108
+ step=0.5,
109
+ value=1.5,
110
+ )
111
  max_length = gr.Slider(
112
  label="Max Length",
113
  minimum=1,
114
+ maximum=512,
115
  step=1,
116
+ value=50,
117
+ )
118
+ min_length = gr.Slider(
119
+ label="Minimum Length",
120
+ minimum=1,
121
+ maximum=100,
122
+ step=1,
123
+ value=1,
124
+ )
125
+ top_p = gr.Slider(
126
+ label="Top P",
127
+ info="Used with nucleus sampling.",
128
+ minimum=0.5,
129
+ maximum=1.0,
130
+ step=0.1,
131
+ value=0.9,
132
  )
 
133
  chat_output = [
134
  chatbot,
135
  history_chat
136
  ]
137
  chat_button.click(fn=chat, inputs=[image,
138
  text_input,
139
+ temperature,
140
+ length_penalty,
141
+ repetition_penalty,
142
  max_length,
143
+ min_length,
144
+ top_p,
145
  history_chat],
146
  outputs=chat_output,
147
  api_name="Chat",
 
150
  chat_inputs = [
151
  image,
152
  text_input,
153
+ temperature,
154
+ length_penalty,
155
+ repetition_penalty,
156
  max_length,
157
+ min_length,
158
+ top_p,
159
  history_chat
160
  ]
161
  text_input.submit(
 
190
 
191
 
192
  if __name__ == "__main__":
193
+ demo.queue(max_size=10).launch()