[update]edit main
Browse files- .gitignore +1 -0
- main.py +18 -7
.gitignore
CHANGED
@@ -5,5 +5,6 @@
|
|
5 |
**/flagged/
|
6 |
**/__pycache__/
|
7 |
|
|
|
8 |
flagged/
|
9 |
trained_models/
|
|
|
5 |
**/flagged/
|
6 |
**/__pycache__/
|
7 |
|
8 |
+
cache/
|
9 |
flagged/
|
10 |
trained_models/
|
main.py
CHANGED
@@ -4,6 +4,11 @@ import argparse
|
|
4 |
from collections import defaultdict
|
5 |
import os
|
6 |
import platform
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
import gradio as gr
|
9 |
from threading import Thread
|
@@ -12,8 +17,6 @@ from transformers.models.bert.tokenization_bert import BertTokenizer
|
|
12 |
from transformers.generation.streamers import TextIteratorStreamer
|
13 |
import torch
|
14 |
|
15 |
-
from project_settings import project_path
|
16 |
-
|
17 |
|
18 |
def get_args():
|
19 |
parser = argparse.ArgumentParser()
|
@@ -38,6 +41,11 @@ examples = [
|
|
38 |
]
|
39 |
|
40 |
|
|
|
|
|
|
|
|
|
|
|
41 |
def main():
|
42 |
args = get_args()
|
43 |
|
@@ -94,13 +102,16 @@ def main():
|
|
94 |
if first_answer:
|
95 |
first_answer = False
|
96 |
continue
|
97 |
-
|
98 |
-
|
99 |
-
output_ = output_.replace("[SEP]", "\n")
|
100 |
output_ = output_.replace("[UNK]", "")
|
101 |
-
output_ = output_.replace(" ", "")
|
102 |
|
103 |
-
output += output_
|
|
|
|
|
|
|
|
|
|
|
104 |
output_text_box.value += output
|
105 |
yield output
|
106 |
|
|
|
4 |
from collections import defaultdict
|
5 |
import os
|
6 |
import platform
|
7 |
+
import re
|
8 |
+
|
9 |
+
from project_settings import project_path
|
10 |
+
|
11 |
+
os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()
|
12 |
|
13 |
import gradio as gr
|
14 |
from threading import Thread
|
|
|
17 |
from transformers.generation.streamers import TextIteratorStreamer
|
18 |
import torch
|
19 |
|
|
|
|
|
20 |
|
21 |
def get_args():
|
22 |
parser = argparse.ArgumentParser()
|
|
|
41 |
]
|
42 |
|
43 |
|
44 |
+
def repl(match):
    """Return capture group 1 immediately followed by capture group 2.

    Used as the replacement callback for ``re.sub``: whatever the pattern
    matched between the two groups is dropped from the substituted text.

    :param match: regex match object exposing at least two capture groups.
    :return: the text of group 1 and group 2 joined with nothing in between.
    """
    # Fetch both groups in one call; ``group`` returns a tuple for 2+ indices.
    first, second = match.group(1, 2)
    return "{}{}".format(first, second)
|
47 |
+
|
48 |
+
|
49 |
def main():
|
50 |
args = get_args()
|
51 |
|
|
|
102 |
if first_answer:
|
103 |
first_answer = False
|
104 |
continue
|
105 |
+
|
106 |
+
output_ = output_.replace("[UNK] ", "")
|
|
|
107 |
output_ = output_.replace("[UNK]", "")
|
|
|
108 |
|
109 |
+
output += output_
|
110 |
+
|
111 |
+
output = output.lstrip("[SEP] ,.!?")
|
112 |
+
output = output.replace("[SEP]", "\n")
|
113 |
+
output = re.sub(r"([\u4e00-\u9fa5]) ([\u4e00-\u9fa5])", repl, output)
|
114 |
+
|
115 |
output_text_box.value += output
|
116 |
yield output
|
117 |
|