zxw committed on
Commit 23939ce
1 Parent(s): 14c7cad
Files changed (1)
  1. app.py +26 -6
app.py CHANGED
@@ -122,15 +122,26 @@ def preprocess(text, task):
 
     return task_style_to_task_prefix[task] + "\n" + text + "\n答案:"
 
-def inference_gen(text, task):
+def inference_gen(text, task, do_sample, top_p, top_k, max_token, temperature, beam_size, length_penalty):
     text = preprocess(text, task)
+    generate_config = {
+        "do_sample": do_sample,
+        "top_p": top_p,
+        "top_k": top_k,
+        "max_length": max_token,
+        "temperature": temperature,
+        "num_beams": beam_size,
+        "length_penalty": length_penalty
+    }
+    #print(generate_config)
     #print(text)
     try:
         prediction = cl.generate(
             model_name='clueai-base',
-            prompt=text)
+            prompt=text,
+            generate_config=generate_config)
     except Exception as e:
-        logger.error(f"error, e")
+        logger.error(f"error, {e}")
         return
 
     return prediction.generations[0].text
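Note on this hunk: the new decoding knobs are gathered into generate_config and forwarded to cl.generate. A minimal sketch (not part of the commit) of calling the updated inference_gen with the defaults the new controls expose; the input text and task value below are placeholders:

    # Hypothetical call for illustration; assumes `cl` was initialized earlier in app.py.
    result = inference_gen(
        text="一个标题",        # placeholder input
        task="标题生成文章",
        do_sample=False,        # default of the new sampling radio
        top_p=0,
        top_k=50,
        max_token=64,           # forwarded as "max_length"
        temperature=1,
        beam_size=1,            # forwarded as "num_beams"
        length_penalty=0.6,
    )
    print(result)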
@@ -141,7 +152,9 @@ from io import BytesIO
 from PIL import Image
 def inference_image(text, n_text, guidance_scale, style, shape, clarity, steps, shape_scale):
     try:
-        res = requests.get(f"https://www.clueai.cn/clueai/hf_text2image?text={text}")
+        res = requests.get(f"https://www.clueai.cn/clueai/hf_text2image?text={text}&negative_prompt={n_text}\
+&guidance_scale={guidance_scale}&num_inference_steps={steps}\
+&style={style}&shape={shape}&clarity={clarity}&shape_scale={shape_scale}")
     except Exception as e:
         logger.error(f"error, {e}")
         return
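Note on this hunk: the extra generation settings are appended to the query string by hand inside an f-string with backslash continuations, so the values are not URL-encoded. A sketch of an equivalent call (not in the commit) that lets requests build and encode the query string through its params argument:

    res = requests.get(
        "https://www.clueai.cn/clueai/hf_text2image",
        params={
            "text": text,
            "negative_prompt": n_text,
            "guidance_scale": guidance_scale,
            "num_inference_steps": steps,
            "style": style,
            "shape": shape,
            "clarity": clarity,
            "shape_scale": shape_scale,
        },
    )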
@@ -167,14 +180,21 @@ with gr.Blocks(css=css, title="ClueAI") as demo:
         task = gr.Dropdown(label="任务", show_label=True, choices=task_styles, value="标题生成文章")
         btn = gr.Button("生成",elem_id="gen_btn_1").style(full_width=False)
         with gr.Accordion("高级操作", open=False):
-            pass
+            do_sample = gr.Radio([True, False], label="是否采样", value=False)
+            top_p = gr.Slider(0, 1, value=0, step=0.1, label="越大多样性越高, 按照概率采样")
+            top_k = gr.Slider(1, 100, value=50, step=1, label="越大多样性越高,按照top k采样")
+            max_token = gr.Slider(1, 512, value=64, step=1, label="生成的最大长度")
+            temperature = gr.Slider(0,1, value=1, step=0.1, label="temperature, 越小下一个token预测概率越平滑")
+            beam_size = gr.Slider(1, 4, value=1, step=1, label="beam size, 越大解码窗口越广,")
+            length_penalty = gr.Slider(-1, 1, value=0.6, step=0.1, label="大于0鼓励长句子,小于0鼓励短句子")
+
     with gr.Row(variant="compact").style( equal_height=True):
         output_text = gr.Textbox(
             label="输出", show_label=True, max_lines=50,
             placeholder="在这里展示结果",
         )
     gr.Examples(examples_list, [task, text], label="示例")
-    input_params = [text, task]
+    input_params = [text, task, do_sample, top_p, top_k, max_token, temperature, beam_size, length_penalty]
     #text.submit(inference_gen, inputs=input_params, outputs=output_text)
     btn.click(inference_gen, inputs=input_params, outputs=output_text)
 
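Note on this hunk: btn.click passes the values of the listed components to the callback positionally, so input_params has to stay in the same order as the parameters of the updated inference_gen, which this commit does. A short sketch of the mapping for reference:

    # Component order lines up one-to-one with
    # inference_gen(text, task, do_sample, top_p, top_k, max_token, temperature, beam_size, length_penalty)
    input_params = [text, task, do_sample, top_p, top_k,
                    max_token, temperature, beam_size, length_penalty]
    btn.click(inference_gen, inputs=input_params, outputs=output_text)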