svjack commited on
Commit
da582ee
·
1 Parent(s): 9ea124a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from predict import *
2
+ from reconstructor import *
3
+ from transformers import BertTokenizer, GPT2LMHeadModel
4
+
5
+ import os
6
+ import gradio as gr
7
+
8
+ model_path = "svjack/gpt-daliy-dialogue"
9
+ tokenizer = BertTokenizer.from_pretrained(model_path)
10
+ model = GPT2LMHeadModel.from_pretrained(model_path)
11
+
12
+ obj = Obj(model, tokenizer)
13
+
14
+ example_sample = [
15
+ ["这只狗很凶,", 128],
16
+ ["你饿吗?", 128],
17
+ ]
18
+
19
+ def demo_func(prefix, max_length):
20
+ max_length = max(int(max_length), 32)
21
+ x = obj.predict(prefix, max_length=max_length)[0]
22
+ y = list(map(lambda x: "".join(x).replace(" ", ""),batch_as_list(re.split(r"([。.??])" ,x), 2)))
23
+ l = predict_split(y)
24
+ assert type(l) == type([])
25
+ return {
26
+ "Dialogue Context": l
27
+ }
28
+
29
+ demo = gr.Interface(
30
+ fn=demo_func,
31
+ inputs=[gr.Text(label = "Prefix"),
32
+ gr.Number(label = "Max Length", value = 128)
33
+ ],
34
+ outputs="json",
35
+ title=f"GPT Chinese Daliy Dialogue Generator 🐰 demonstration",
36
+ examples=example_sample if example_sample else None,
37
+ cache_examples = False
38
+ )
39
+
40
+ demo.launch(server_name=None, server_port=None)