tongxiaojun commited on
Commit
4c78b8e
·
1 Parent(s): feea4c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -1
app.py CHANGED
@@ -2,13 +2,32 @@ import gradio as gr
2
  import random
3
  import time
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  with gr.Blocks() as demo:
6
  chatbot = gr.Chatbot()
7
  msg = gr.Textbox()
8
  clear = gr.Button("Clear")
9
 
10
  def respond(message, chat_history):
11
- bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
 
12
  chat_history.append((message, bot_message))
13
  time.sleep(1)
14
  return "", chat_history
 
2
  import random
3
  import time
4
 
5
+ from transformers import BloomTokenizerFast, BloomForCausalLM
6
+ path = 'YeungNLP/firefly-2b6-v2'
7
+
8
+ tokenizer = BloomTokenizerFast.from_pretrained(path)
9
+ model = BloomForCausalLM.from_pretrained(path)
10
+ model.eval()
11
+
12
+
13
+ def generate(text):
14
+ text = '<s>{}</s></s>'.format(text)
15
+ input_ids = tokenizer(text, return_tensors="pt").input_ids
16
+ #input_ids = input_ids.to(device)
17
+ outputs = model.generate(input_ids, max_new_tokens=200, do_sample=True, top_p=0.7, temperature=0.35,
18
+ repetition_penalty=1.2, eos_token_id=tokenizer.eos_token_id)
19
+ rets = tokenizer.batch_decode(outputs)
20
+ output = rets[0].strip().replace(text, "").replace('</s>', "")
21
+ return output
22
+
23
  with gr.Blocks() as demo:
24
  chatbot = gr.Chatbot()
25
  msg = gr.Textbox()
26
  clear = gr.Button("Clear")
27
 
28
  def respond(message, chat_history):
29
+ #bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
30
+ bot_message = generate(message)
31
  chat_history.append((message, bot_message))
32
  time.sleep(1)
33
  return "", chat_history