Jasonai commited on
Commit
a360cd7
·
1 Parent(s): 06a04f1

feat(支持rar格式与7z格式解压)

Browse files
Files changed (2) hide show
  1. requirements.txt +9 -2
  2. toolbox.py +67 -24
requirements.txt CHANGED
@@ -1,3 +1,10 @@
1
  gradio>=3.23
2
- requests[socks]
3
- mdtex2html
 
 
 
 
 
 
 
 
1
  gradio>=3.23
2
+ requests[socks]~=2.28.2
3
+ mdtex2html~=1.2.0
4
+
5
+ markdown~=3.4.3
6
+ latex2mathml~=3.75.1
7
+ numpy~=1.21.6
8
+
9
+ rarfile~=4.0
10
+ py7zr~=0.20.4
toolbox.py CHANGED
@@ -2,6 +2,7 @@ import markdown, mdtex2html, threading, importlib, traceback
2
  from show_math import convert as convert_math
3
  from functools import wraps
4
 
 
5
  def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[], sys_prompt=''):
6
  """
7
  调用简单的predict_no_ui接口,但是依然保留了些许界面心跳功能,当对话太长时,会自动采用二分法截断
@@ -13,36 +14,43 @@ def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temp
13
  # 多线程的时候,需要一个mutable结构在不同线程之间传递信息
14
  # list就是最简单的mutable结构,我们第一个位置放gpt输出,第二个位置传递报错信息
15
  mutable = [None, '']
 
16
  # multi-threading worker
17
  def mt(i_say, history):
18
  while True:
19
  try:
20
- mutable[0] = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
 
21
  break
22
  except ConnectionAbortedError as e:
23
  if len(history) > 0:
24
- history = [his[len(his)//2:] for his in history if his is not None]
25
  mutable[1] = 'Warning! History conversation is too long, cut into half. '
26
  else:
27
- i_say = i_say[:len(i_say)//2]
28
  mutable[1] = 'Warning! Input file is too long, cut into half. '
29
  except TimeoutError as e:
30
  mutable[0] = '[Local Message] Failed with timeout.'
31
  raise TimeoutError
 
32
  # 创建新线程发出http请求
33
- thread_name = threading.Thread(target=mt, args=(i_say, history)); thread_name.start()
 
34
  # 原来的线程则负责持续更新UI,实现一个超时倒计时,并等待新线程的任务完成
35
  cnt = 0
36
  while thread_name.is_alive():
37
  cnt += 1
38
- chatbot[-1] = (i_say_show_user, f"[Local Message] {mutable[1]}waiting gpt response {cnt}/{TIMEOUT_SECONDS*2*(MAX_RETRY+1)}"+''.join(['.']*(cnt%4)))
 
 
39
  yield chatbot, history, '正常'
40
  time.sleep(1)
41
  # 把gpt的输出从mutable中取出来
42
  gpt_say = mutable[0]
43
- if gpt_say=='[Local Message] Failed with timeout.': raise TimeoutError
44
  return gpt_say
45
 
 
46
  def write_results_to_file(history, file_name=None):
47
  """
48
  将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
@@ -52,16 +60,17 @@ def write_results_to_file(history, file_name=None):
52
  # file_name = time.strftime("chatGPT分析报告%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
53
  file_name = 'chatGPT分析报告' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
54
  os.makedirs('./gpt_log/', exist_ok=True)
55
- with open(f'./gpt_log/{file_name}', 'w', encoding = 'utf8') as f:
56
  f.write('# chatGPT 分析报告\n')
57
  for i, content in enumerate(history):
58
- if i%2==0: f.write('## ')
59
  f.write(content)
60
  f.write('\n\n')
61
  res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
62
  print(res)
63
  return res
64
 
 
65
  def regular_txt_to_markdown(text):
66
  """
67
  将普通文本转换为Markdown格式的文本。
@@ -71,10 +80,12 @@ def regular_txt_to_markdown(text):
71
  text = text.replace('\n\n\n', '\n\n')
72
  return text
73
 
 
74
  def CatchException(f):
75
  """
76
  装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
77
  """
 
78
  @wraps(f)
79
  def decorated(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
80
  try:
@@ -84,16 +95,21 @@ def CatchException(f):
84
  from toolbox import get_conf
85
  proxies, = get_conf('proxies')
86
  tb_str = regular_txt_to_markdown(traceback.format_exc())
87
- chatbot[-1] = (chatbot[-1][0], f"[Local Message] 实验性函数调用出错: \n\n {tb_str} \n\n 当前代理可用性: \n\n {check_proxy(proxies)}")
 
88
  yield chatbot, history, f'异常 {e}'
 
89
  return decorated
90
 
 
91
  def report_execption(chatbot, history, a, b):
92
  """
93
  向chatbot中添加错误信息
94
  """
95
  chatbot.append((a, b))
96
- history.append(a); history.append(b)
 
 
97
 
98
  def text_divide_paragraph(text):
99
  """
@@ -110,15 +126,16 @@ def text_divide_paragraph(text):
110
  text = "</br>".join(lines)
111
  return text
112
 
 
113
  def markdown_convertion(txt):
114
  """
115
  将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
116
  """
117
  if ('$' in txt) and ('```' not in txt):
118
- return markdown.markdown(txt,extensions=['fenced_code','tables']) + '<br><br>' + \
119
- markdown.markdown(convert_math(txt, splitParagraphs=False),extensions=['fenced_code','tables'])
120
  else:
121
- return markdown.markdown(txt,extensions=['fenced_code','tables'])
122
 
123
 
124
  def format_io(self, y):
@@ -127,9 +144,9 @@ def format_io(self, y):
127
  """
128
  if y is None or y == []: return []
129
  i_ask, gpt_reply = y[-1]
130
- i_ask = text_divide_paragraph(i_ask) # 输入部分太自由,预处理一波
131
  y[-1] = (
132
- None if i_ask is None else markdown.markdown(i_ask, extensions=['fenced_code','tables']),
133
  None if gpt_reply is None else markdown_convertion(gpt_reply)
134
  )
135
  return y
@@ -151,6 +168,7 @@ def extract_archive(file_path, dest_dir):
151
  import zipfile
152
  import tarfile
153
  import os
 
154
  # Get the file extension of the input file
155
  file_extension = os.path.splitext(file_path)[1]
156
 
@@ -164,9 +182,28 @@ def extract_archive(file_path, dest_dir):
164
  with tarfile.open(file_path, 'r:*') as tarobj:
165
  tarobj.extractall(path=dest_dir)
166
  print("Successfully extracted tar archive to {}".format(dest_dir))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  else:
168
  return
169
 
 
170
  def find_recent_files(directory):
171
  """
172
  me: find files that is created with in one minutes under a directory with python, write a function
@@ -193,19 +230,21 @@ def on_file_uploaded(files, chatbot, txt):
193
  if len(files) == 0: return chatbot, txt
194
  import shutil, os, time, glob
195
  from toolbox import extract_archive
196
- try: shutil.rmtree('./private_upload/')
197
- except: pass
 
 
198
  time_tag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
199
  os.makedirs(f'private_upload/{time_tag}', exist_ok=True)
200
  for file in files:
201
  file_origin_name = os.path.basename(file.orig_name)
202
  shutil.copy(file.name, f'private_upload/{time_tag}/{file_origin_name}')
203
- extract_archive(f'private_upload/{time_tag}/{file_origin_name}',
204
  dest_dir=f'private_upload/{time_tag}/{file_origin_name}.extract')
205
  moved_files = [fp for fp in glob.glob('private_upload/**/*', recursive=True)]
206
  txt = f'private_upload/{time_tag}'
207
  moved_files_str = '\t\n\n'.join(moved_files)
208
- chatbot.append(['我上传了文件,请查收',
209
  f'[Local Message] 收到以下文件: \n\n{moved_files_str}\n\n调用路径参数已自动修正到: \n\n{txt}\n\n现在您点击任意实验功能时,以上文件将被作为输入参数'])
210
  return chatbot, txt
211
 
@@ -218,21 +257,25 @@ def on_report_generated(files, chatbot):
218
  chatbot.append(['汇总报告如何远程获取?', '汇总报告已经添加到右侧文件上传区,请查收。'])
219
  return report_files, chatbot
220
 
 
221
  def get_conf(*args):
222
  # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到
223
  res = []
224
  for arg in args:
225
- try: r = getattr(importlib.import_module('config_private'), arg)
226
- except: r = getattr(importlib.import_module('config'), arg)
 
 
227
  res.append(r)
228
  # 在读取API_KEY时,检查一下是不是忘了改config
229
- if arg=='API_KEY' and len(r) != 51:
230
  assert False, "正确的API_KEY密钥是51位,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \
231
- "(如果您刚更新过代码,请确保旧版config_private文件中没有遗留任何新增键值)"
232
  return res
233
 
 
234
  def clear_line_break(txt):
235
  txt = txt.replace('\n', ' ')
236
  txt = txt.replace(' ', ' ')
237
  txt = txt.replace(' ', ' ')
238
- return txt
 
2
  from show_math import convert as convert_math
3
  from functools import wraps
4
 
5
+
6
  def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[], sys_prompt=''):
7
  """
8
  调用简单的predict_no_ui接口,但是依然保留了些许界面心跳功能,当对话太长时,会自动采用二分法截断
 
14
  # 多线程的时候,需要一个mutable结构在不同线程之间传递信息
15
  # list就是最简单的mutable结构,我们第一个位置放gpt输出,第二个位置传递报错信息
16
  mutable = [None, '']
17
+
18
  # multi-threading worker
19
  def mt(i_say, history):
20
  while True:
21
  try:
22
+ mutable[0] = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature, history=history,
23
+ sys_prompt=sys_prompt)
24
  break
25
  except ConnectionAbortedError as e:
26
  if len(history) > 0:
27
+ history = [his[len(his) // 2:] for his in history if his is not None]
28
  mutable[1] = 'Warning! History conversation is too long, cut into half. '
29
  else:
30
+ i_say = i_say[:len(i_say) // 2]
31
  mutable[1] = 'Warning! Input file is too long, cut into half. '
32
  except TimeoutError as e:
33
  mutable[0] = '[Local Message] Failed with timeout.'
34
  raise TimeoutError
35
+
36
  # 创建新线程发出http请求
37
+ thread_name = threading.Thread(target=mt, args=(i_say, history));
38
+ thread_name.start()
39
  # 原来的线程则负责持续更新UI,实现一个超时倒计时,并等待新线程的任务完成
40
  cnt = 0
41
  while thread_name.is_alive():
42
  cnt += 1
43
+ chatbot[-1] = (i_say_show_user,
44
+ f"[Local Message] {mutable[1]}waiting gpt response {cnt}/{TIMEOUT_SECONDS * 2 * (MAX_RETRY + 1)}" + ''.join(
45
+ ['.'] * (cnt % 4)))
46
  yield chatbot, history, '正常'
47
  time.sleep(1)
48
  # 把gpt的输出从mutable中取出来
49
  gpt_say = mutable[0]
50
+ if gpt_say == '[Local Message] Failed with timeout.': raise TimeoutError
51
  return gpt_say
52
 
53
+
54
  def write_results_to_file(history, file_name=None):
55
  """
56
  将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
 
60
  # file_name = time.strftime("chatGPT分析报告%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
61
  file_name = 'chatGPT分析报告' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
62
  os.makedirs('./gpt_log/', exist_ok=True)
63
+ with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f:
64
  f.write('# chatGPT 分析报告\n')
65
  for i, content in enumerate(history):
66
+ if i % 2 == 0: f.write('## ')
67
  f.write(content)
68
  f.write('\n\n')
69
  res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
70
  print(res)
71
  return res
72
 
73
+
74
  def regular_txt_to_markdown(text):
75
  """
76
  将普通文本转换为Markdown格式的文本。
 
80
  text = text.replace('\n\n\n', '\n\n')
81
  return text
82
 
83
+
84
  def CatchException(f):
85
  """
86
  装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
87
  """
88
+
89
  @wraps(f)
90
  def decorated(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
91
  try:
 
95
  from toolbox import get_conf
96
  proxies, = get_conf('proxies')
97
  tb_str = regular_txt_to_markdown(traceback.format_exc())
98
+ chatbot[-1] = (
99
+ chatbot[-1][0], f"[Local Message] 实验性函数调用出错: \n\n {tb_str} \n\n 当前代理可用性: \n\n {check_proxy(proxies)}")
100
  yield chatbot, history, f'异常 {e}'
101
+
102
  return decorated
103
 
104
+
105
  def report_execption(chatbot, history, a, b):
106
  """
107
  向chatbot中添加错误信息
108
  """
109
  chatbot.append((a, b))
110
+ history.append(a);
111
+ history.append(b)
112
+
113
 
114
  def text_divide_paragraph(text):
115
  """
 
126
  text = "</br>".join(lines)
127
  return text
128
 
129
+
130
  def markdown_convertion(txt):
131
  """
132
  将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
133
  """
134
  if ('$' in txt) and ('```' not in txt):
135
+ return markdown.markdown(txt, extensions=['fenced_code', 'tables']) + '<br><br>' + \
136
+ markdown.markdown(convert_math(txt, splitParagraphs=False), extensions=['fenced_code', 'tables'])
137
  else:
138
+ return markdown.markdown(txt, extensions=['fenced_code', 'tables'])
139
 
140
 
141
  def format_io(self, y):
 
144
  """
145
  if y is None or y == []: return []
146
  i_ask, gpt_reply = y[-1]
147
+ i_ask = text_divide_paragraph(i_ask) # 输入部分太自由,预处理一波
148
  y[-1] = (
149
+ None if i_ask is None else markdown.markdown(i_ask, extensions=['fenced_code', 'tables']),
150
  None if gpt_reply is None else markdown_convertion(gpt_reply)
151
  )
152
  return y
 
168
  import zipfile
169
  import tarfile
170
  import os
171
+
172
  # Get the file extension of the input file
173
  file_extension = os.path.splitext(file_path)[1]
174
 
 
182
  with tarfile.open(file_path, 'r:*') as tarobj:
183
  tarobj.extractall(path=dest_dir)
184
  print("Successfully extracted tar archive to {}".format(dest_dir))
185
+
186
+ elif file_extension == '.rar':
187
+ # 这是个第三方库,需要预先pip install rarfile
188
+ # 此外,Windows上还需要安装winrar软件,配置其Path环境变量,如"C:\Program Files\WinRAR"才可以正常运行
189
+ try:
190
+ import rarfile
191
+ with rarfile.RarFile(file_path) as rf:
192
+ rf.extractall(path=dest_dir)
193
+ print("Successfully extracted rar archive to {}".format(dest_dir))
194
+ except:
195
+ print("rar格式需要安装额外依赖")
196
+ elif file_extension == '.7z':
197
+ try:
198
+ import py7zr
199
+ with py7zr.SevenZipFile(file_path, mode='r') as f:
200
+ f.extractall(path=dest_dir)
201
+ except:
202
+ print("7z格式需要安装额外依赖")
203
  else:
204
  return
205
 
206
+
207
  def find_recent_files(directory):
208
  """
209
  me: find files that is created with in one minutes under a directory with python, write a function
 
230
  if len(files) == 0: return chatbot, txt
231
  import shutil, os, time, glob
232
  from toolbox import extract_archive
233
+ try:
234
+ shutil.rmtree('./private_upload/')
235
+ except:
236
+ pass
237
  time_tag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
238
  os.makedirs(f'private_upload/{time_tag}', exist_ok=True)
239
  for file in files:
240
  file_origin_name = os.path.basename(file.orig_name)
241
  shutil.copy(file.name, f'private_upload/{time_tag}/{file_origin_name}')
242
+ extract_archive(f'private_upload/{time_tag}/{file_origin_name}',
243
  dest_dir=f'private_upload/{time_tag}/{file_origin_name}.extract')
244
  moved_files = [fp for fp in glob.glob('private_upload/**/*', recursive=True)]
245
  txt = f'private_upload/{time_tag}'
246
  moved_files_str = '\t\n\n'.join(moved_files)
247
+ chatbot.append(['我上传了文件,请查收',
248
  f'[Local Message] 收到以下文件: \n\n{moved_files_str}\n\n调用路径参数已自动修正到: \n\n{txt}\n\n现在您点击任意实验功能时,以上文件将被作为输入参数'])
249
  return chatbot, txt
250
 
 
257
  chatbot.append(['汇总报告如何远程获取?', '汇总报告已经添加到右侧文件上传区,请查收。'])
258
  return report_files, chatbot
259
 
260
+
261
  def get_conf(*args):
262
  # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到
263
  res = []
264
  for arg in args:
265
+ try:
266
+ r = getattr(importlib.import_module('config_private'), arg)
267
+ except:
268
+ r = getattr(importlib.import_module('config'), arg)
269
  res.append(r)
270
  # 在读取API_KEY时,检查一下是不是忘了改config
271
+ if arg == 'API_KEY' and len(r) != 51:
272
  assert False, "正确的API_KEY密钥是51位,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \
273
+ "(如果您刚更新过代码,请确保旧版config_private文件中没有遗留任何新增键值)"
274
  return res
275
 
276
+
277
  def clear_line_break(txt):
278
  txt = txt.replace('\n', ' ')
279
  txt = txt.replace(' ', ' ')
280
  txt = txt.replace(' ', ' ')
281
+ return txt