theodotus committed on
Commit
3f45785
·
1 Parent(s): 8521c6a

Added copy of pythia-uk

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctranslate2
2
+ from transformers import AutoTokenizer
3
+
4
+ import threading
5
+ import gradio as gr
6
+
7
+ from typing import Optional
8
+ from queue import Queue
9
+
10
+
11
+
12
+
13
class TokenIteratorStreamer:
    """Blocking iterator over token ids produced on another thread.

    A producer pushes token ids with put(); iterating the streamer yields
    them in arrival order and stops as soon as the configured
    end-of-sequence id is received.
    """

    def __init__(self, end_token_id: int, timeout: Optional[float] = None):
        self.end_token_id = end_token_id  # sentinel id that ends iteration
        self.queue = Queue()              # thread-safe producer/consumer hand-off
        self.timeout = timeout            # max seconds to block on put/get

    def put(self, token: int):
        """Enqueue one token id (called from the generation thread)."""
        self.queue.put(token, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        next_token = self.queue.get(timeout=self.timeout)
        if next_token == self.end_token_id:
            raise StopIteration()
        return next_token
31
+
32
+
33
+
34
def generate_prompt(history):
    """Render the chat history into the <human>/<bot> prompt and tokenize it.

    `history` is a sequence of [user_message, bot_message] pairs; the last
    pair's bot slot is left open so the model generates the reply.
    Returns the prompt as a list of token strings.
    """
    turns = [f"<human>: {pair[0]}\n<bot>: {pair[1]}\n" for pair in history[:-1]]
    prompt = "".join(turns) + f"<human>: {history[-1][0]}\n<bot>:"
    # NOTE(review): relies on the module-level `tokenizer` being loaded first.
    return tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))
41
+
42
def generate(streamer, history):
    """Run greedy generation for `history`, streaming token ids into `streamer`.

    Every decoded token id is forwarded to the streamer as it is produced;
    if decoding ends without the end-of-sequence id (e.g. max length hit),
    one is injected so the consuming iterator still terminates.
    """

    def on_step(result):
        streamer.put(result.token_id)
        # Guarantee a terminator even when generation stops for another reason.
        if result.is_last and (result.token_id != end_token_id):
            streamer.put(end_token_id)
        print(f"step={result.step}, batch_id={result.batch_id}, token={result.token}")

    prompt_tokens = generate_prompt(history)

    return translator.translate_batch(
        [prompt_tokens],
        beam_size=1,
        max_decoding_length=256,
        repetition_penalty=1.8,
        callback=on_step,
    )
59
+
60
+
61
+
62
# Load the CTranslate2-converted model from the local "model" directory,
# using 2 intra-op threads for decoding.
translator = ctranslate2.Translator("model", intra_threads=2)
# Tokenizer of the base mT5 checkpoint the model was adapted from.
tokenizer = AutoTokenizer.from_pretrained("DKYoon/mt5-xl-lm-adapt")
# End-of-sequence marker; its first encoded id terminates the token stream.
end_token = "</s>"
end_token_id = tokenizer.encode(end_token)[0]
66
+
67
+
68
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Append the new user turn with an empty bot slot and clear the textbox."""
        return "", history + [[user_message, ""]]

    def bot(history):
        """Stream the model reply token-by-token into the last history entry."""
        streamer = TokenIteratorStreamer(end_token_id=end_token_id)
        worker = threading.Thread(target=generate, args=(streamer, history))
        worker.start()

        collected = []
        for token in streamer:
            collected.append(token)
            # Re-decode the whole partial reply so the UI shows valid text.
            history[-1][1] = tokenizer.decode(collected)
            yield history
        worker.join()

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
if __name__ == "__main__":
    demo.launch()