qgyd2021 commited on
Commit
341b916
1 Parent(s): d789a59

[update]edit main

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. main.py +18 -7
.gitignore CHANGED
@@ -5,5 +5,6 @@
5
  **/flagged/
6
  **/__pycache__/
7
 
 
8
  flagged/
9
  trained_models/
 
5
  **/flagged/
6
  **/__pycache__/
7
 
8
+ cache/
9
  flagged/
10
  trained_models/
main.py CHANGED
@@ -4,6 +4,11 @@ import argparse
4
  from collections import defaultdict
5
  import os
6
  import platform
 
 
 
 
 
7
 
8
  import gradio as gr
9
  from threading import Thread
@@ -12,8 +17,6 @@ from transformers.models.bert.tokenization_bert import BertTokenizer
12
  from transformers.generation.streamers import TextIteratorStreamer
13
  import torch
14
 
15
- from project_settings import project_path
16
-
17
 
18
  def get_args():
19
  parser = argparse.ArgumentParser()
@@ -38,6 +41,11 @@ examples = [
38
  ]
39
 
40
 
 
 
 
 
 
41
  def main():
42
  args = get_args()
43
 
@@ -94,13 +102,16 @@ def main():
94
  if first_answer:
95
  first_answer = False
96
  continue
97
- # output_ = output_.replace(text, "")
98
- # output_ = output_.replace("[CLS]", "")
99
- output_ = output_.replace("[SEP]", "\n")
100
  output_ = output_.replace("[UNK]", "")
101
- output_ = output_.replace(" ", "")
102
 
103
- output += output_.strip()
 
 
 
 
 
104
  output_text_box.value += output
105
  yield output
106
 
 
4
  from collections import defaultdict
5
  import os
6
  import platform
7
+ import re
8
+
9
+ from project_settings import project_path
10
+
11
+ os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()
12
 
13
  import gradio as gr
14
  from threading import Thread
 
17
  from transformers.generation.streamers import TextIteratorStreamer
18
  import torch
19
 
 
 
20
 
21
  def get_args():
22
  parser = argparse.ArgumentParser()
 
41
  ]
42
 
43
 
44
+ def repl(match):
45
+ result = "{}{}".format(match.group(1), match.group(2))
46
+ return result
47
+
48
+
49
  def main():
50
  args = get_args()
51
 
 
102
  if first_answer:
103
  first_answer = False
104
  continue
105
+
106
+ output_ = output_.replace("[UNK] ", "")
 
107
  output_ = output_.replace("[UNK]", "")
 
108
 
109
+ output += output_
110
+
111
+ output = output.lstrip("[SEP] ,.!?")
112
+ output = output.replace("[SEP]", "\n")
113
+ output = re.sub(r"([\u4e00-\u9fa5]) ([\u4e00-\u9fa5])", repl, output)
114
+
115
  output_text_box.value += output
116
  yield output
117