File size: 11,033 Bytes
47289f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
import markdown
import re
import os
import math
from textwrap import dedent
from functools import lru_cache
from pymdownx.superfences import fence_div_format, fence_code_format
from latex2mathml.converter import convert as tex2mathml
from shared_utils.config_loader import get_conf as get_conf

# Short alias for os.path.join.
pj = os.path.join
# Module-level default user name string.
default_user_name = 'default_user'

# Options for the `mdx_math` Markdown extension. markdown_convertion() merges
# this dict with code_highlight_configs when the text contains equations.
# NOTE(review): option meanings are defined by the mdx_math package
# (presumably: accept $...$ as a math delimiter, do not use GitLab-style
# delimiters) — confirm against its documentation.
markdown_extension_configs = {
    'mdx_math': {
        'enable_dollar_delimiter': True,
        'use_gitlab_delimiters': False,
    },
}

# Options for the pymdownx code extensions, passed as `extension_configs`
# to markdown.markdown() by markdown_convertion() and format_io().
code_highlight_configs = {
    "pymdownx.superfences": {
        'css_class': 'codehilite',
        # Register a custom fence named 'mermaid' that is emitted through
        # fence_code_format with the CSS class 'mermaid'.
        "custom_fences": [
            {
                'name': 'mermaid',
                'class': 'mermaid',
                'format': fence_code_format
            }
        ]
    },
    "pymdownx.highlight": {
        'css_class': 'codehilite',
        'guess_lang': True,
        # Optional settings, currently disabled:
        # 'auto_title': True,
        # 'linenums': True
    }
}

def text_divide_paragraph(text):
    """
    Wrap free-form plain text in a ``markdown-body`` div, preserving its layout.

    The text is returned untouched when it is already wrapped in the
    ``markdown-body`` div, or when it contains a code fence marker or a
    closing ``</div>`` (such input is treated as deliberately formatted).
    Otherwise every space becomes ``&nbsp;``, lines are joined with
    ``</br>``, and the result is wrapped in the div.

    Args:
        text (str): Raw user input.

    Returns:
        str: Either the original text or the wrapped HTML fragment.
    """
    pre = '<div class="markdown-body">'
    suf = '</div>'
    already_wrapped = text.startswith(pre) and text.endswith(suf)
    # "Careful" input (already wrapped, fenced code, or existing div markup)
    # is passed through as-is.
    if already_wrapped or '```' in text or '</div>' in text:
        return text

    # "Whatever" input: keep the visual spacing and line breaks in HTML.
    body = "</br>".join(row.replace(" ", "&nbsp;") for row in text.split("\n"))
    return pre + body + suf


def tex2mathml_catch_exception(content, *args, **kwargs):
    """
    Convert a TeX snippet to MathML, falling back to the raw input on failure.

    Args:
        content (str): TeX source to convert.
        *args, **kwargs: Forwarded unchanged to ``tex2mathml``
            (e.g. ``display="block"``).

    Returns:
        The converter's output, or ``content`` unchanged if the converter
        raised an exception.
    """
    try:
        return tex2mathml(content, *args, **kwargs)
    except Exception:
        # Best-effort conversion: show the unconverted TeX rather than fail.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate instead of being silently swallowed.
        return content


def replace_math_no_render(match):
    """
    Regex substitution callback: show an equation as colored, copyable TeX.

    Args:
        match: Regex match whose group(1) is the TeX source and whose full
            text (group(0)) contains ``mode=display`` for display equations.

    Returns:
        str: The TeX source in a magenta ``<font>`` tag, surrounded by green
        ``$$`` (display) or ``$`` (inline) delimiters. In display mode,
        newlines inside the TeX are replaced with ``</br>``.
    """
    tex = match.group(1)
    is_display = 'mode=display' in match.group(0)
    if is_display:
        tex = tex.replace('\n', '</br>')
    delimiter = '$$' if is_display else '$'
    marker = f'<font color="#00FF00">{delimiter}</font>'
    return f'{marker}<font color="#FF00FF">{tex}</font>{marker}'


def replace_math_render(match):
    """
    Regex substitution callback: convert an equation's TeX to MathML.

    Args:
        match: Regex match whose group(1) is the TeX source and whose full
            text (group(0)) contains ``mode=display`` for display equations.

    Returns:
        The result of ``tex2mathml_catch_exception`` for the TeX source;
        display equations are converted with ``display="block"``.
    """
    tex = match.group(1)
    if 'mode=display' not in match.group(0):
        # Inline equation: convert as-is.
        return tex2mathml_catch_exception(tex)

    if '\\begin{aligned}' in tex:
        # Rewrite the `aligned` environment as `array` and drop the `&`
        # alignment markers before handing the TeX to the converter.
        tex = (
            tex.replace('\\begin{aligned}', '\\begin{array}')
               .replace('\\end{aligned}', '\\end{array}')
               .replace('&', ' ')
        )
    return tex2mathml_catch_exception(tex, display="block")


def markdown_bug_hunt(content):
    """
    Work around an mdx_math bug: a redundant <script> wrapper is produced
    when a single-$ equation wraps a begin command.

    Args:
        content (str): HTML produced by the Markdown conversion.

    Returns:
        str: The HTML with the doubled opening and closing script tags
        collapsed into single ones.
    """
    repairs = (
        # Doubled opening tag -> the display-mode tag alone.
        ('<script type="math/tex">\n<script type="math/tex; mode=display">',
         '<script type="math/tex; mode=display">'),
        # Doubled closing tag -> a single closing tag.
        ('</script>\n</script>', '</script>'),
    )
    for broken, fixed in repairs:
        content = content.replace(broken, fixed)
    return content


def is_equation(txt):
    """
    Decide whether the text contains math equations that should be rendered.

    Manual test prompts: 1) write the Lorentz law using tex-format formulas;
    2) give the Cauchy inequality in latex format; 3) write Maxwell's equations.

    Returns False when the text has a code fence (other than a
    "reference" fence), when it has no dollar sign or backslash-bracket
    opener, when no delimited span is found, or when any delimited span
    contains a non-ASCII character or the word "echo". Otherwise True.

    Args:
        txt (str): Text to inspect.

    Returns:
        bool: True if every delimited math span looks like a real equation.
    """
    has_code_fence = '```' in txt and '```reference' not in txt
    if has_code_fence:
        return False
    if '$' not in txt and '\\[' not in txt:
        return False

    # (pattern, whether the span may run across several lines).
    # Patterns for \(...\), \begin...\end and $`...`$ are currently disabled.
    delimiter_rules = (
        (r'(?<!\\|\$)(\$)([^\$]+)(\$)', False),     # $...$
        (r'(?<!\\)(\$\$)([^\$]+)(\$\$)', True),     # $$...$$
        (r'(?<!\\)(\\\[)(.+?)(\\\])', False),       # \[...\]
    )
    found = []
    for rule, multi_line in delimiter_rules:
        mode = (re.ASCII | re.DOTALL) if multi_line else re.ASCII
        found += re.findall(rule, txt, mode)
    if not found:
        return False

    # Each hit is (opening delimiter, body, closing delimiter); one suspicious
    # body is enough to reject the whole text.
    forbidden = re.compile(r'[^\x00-\x7F]|echo')
    return all(len(hit) == 3 and not forbidden.search(hit[1]) for hit in found)


def fix_markdown_indent(txt):
    """
    Repair nested list bullets that are indented one space short of a multiple of four.

    Text without both ' - ' and '. ' is returned unchanged. Otherwise, once a
    line starting with '- ' or '1. ' has been seen, any later line made of
    leading whitespace followed by '-' whose indent is 3 modulo 4 is
    re-indented with spaces up to the next multiple of four.

    Args:
        txt (str): Markdown text.

    Returns:
        str: The text with the affected bullet indents corrected.
    """
    if not (' - ' in txt and '. ' in txt):
        # Nothing that could need fixing; fast exit.
        return txt

    nested_bullet = re.compile(r'^\s+-')
    in_list = False
    fixed_rows = []
    for row in txt.split("\n"):
        if row.startswith(('- ', '1. ')):
            in_list = True
        if in_list and nested_bullet.match(row):
            body = row.lstrip()
            indent = len(row) - len(body)
            if indent % 4 == 3:
                # Round the indent up to the next multiple of four.
                row = ' ' * (math.ceil(indent / 4) * 4) + body
        fixed_rows.append(row)
    return '\n'.join(fixed_rows)


# Verbose, multi-line regex matching one fenced code block: an opening fence
# of three or more '~' or '`' (optionally indented), an optional {attrs}
# group or optional language plus hl_lines, the code body, and a closing
# fence identical to the opening one. Named groups: fence, attrs, lang,
# quot, hl_lines, code. Used by fix_code_segment_indent().
FENCED_BLOCK_RE = re.compile(
    dedent(r'''
        (?P<fence>^[ \t]*(?:~{3,}|`{3,}))[ ]*                      # opening fence
        ((\{(?P<attrs>[^\}\n]*)\})|                              # (optional {attrs} or
        (\.?(?P<lang>[\w#.+-]*)[ ]*)?                            # optional (.)lang
        (hl_lines=(?P<quot>"|')(?P<hl_lines>.*?)(?P=quot)[ ]*)?) # optional hl_lines)
        \n                                                       # newline (end of opening fence)
        (?P<code>.*?)(?<=\n)                                     # the code block
        (?P=fence)[ ]*$                                          # closing fence
    '''),
    re.MULTILINE | re.DOTALL | re.VERBOSE
)


def get_line_range(re_match_obj, txt):
    """
    Translate a regex match's character span into a line-index range of txt.

    Args:
        re_match_obj: Regex match object; only its overall span is used.
        txt (str): Text in which the span's character offsets are interpreted.

    Returns:
        tuple[int, int]: ``(line_start, line_end)`` where ``line_start`` is
        the number of newlines in ``txt`` up to and including the match's
        first character, and ``line_end`` is ``line_start`` plus the number
        of newlines inside the span plus one (an exclusive end for ``range``).
    """
    begin, end = re_match_obj.span()
    first_line = txt.count('\n', 0, begin + 1)
    past_last_line = first_line + txt.count('\n', begin, end) + 1
    return first_line, past_last_line


def fix_code_segment_indent(txt):
    """
    Re-indent fenced code blocks whose common indent is one space short of a multiple of four.

    Fenced blocks are located one at a time with FENCED_BLOCK_RE. For each
    block, the smallest leading-whitespace width over its lines is computed;
    if that width is 3 modulo 4, every line of the block is prefixed with
    the spaces needed to reach the next multiple of four.

    Args:
        txt (str): Markdown text.

    Returns:
        str: The re-joined text if at least one block was changed, otherwise
        the original ``txt`` object.
    """
    rows = []
    touched = False
    scratch = txt
    found = FENCED_BLOCK_RE.search(scratch)
    while found:
        if not rows:
            rows = txt.split("\n")

        # Blank out the matched span in the scratch copy (same length, so
        # character offsets stay valid) so the next search skips this block.
        begin, end = found.span()
        scratch = scratch[:begin] + ' ' * (end - begin) + scratch[end:]
        first, past_last = get_line_range(found, txt)
        block_rows = range(first, past_last)

        # Smallest indent shared by every line of the block.
        common_indent = 1e5
        for idx in block_rows:
            lead = len(rows[idx]) - len(rows[idx].lstrip())
            if lead < common_indent:
                common_indent = lead

        # Shift the whole block right so the shared indent becomes a multiple of four.
        if common_indent < 1e5 and common_indent % 4 == 3:
            padding = ' ' * (math.ceil(common_indent / 4) * 4 - common_indent)
            for idx in block_rows:
                rows[idx] = padding + rows[idx]
            touched = True

        found = FENCED_BLOCK_RE.search(scratch)

    return '\n'.join(rows) if touched else txt
    
    
@lru_cache(maxsize=128)  # LRU cache so repeated conversions of the same text are fast
def markdown_convertion(txt):
    """
    Convert Markdown text to HTML wrapped in a ``markdown-body`` div.

    Text that is already wrapped in the ``markdown-body`` div is returned
    unchanged. If the text contains equations (see ``is_equation``), the
    output holds two copies of the converted HTML separated by the HTML for
    a ``---`` rule: first with equations shown as colored, copyable TeX,
    then with equations converted to MathML.

    Args:
        txt (str): Markdown source (must be hashable because of the cache).

    Returns:
        str: The HTML fragment.
    """
    pre = '<div class="markdown-body">'
    suf = '</div>'
    if txt.startswith(pre) and txt.endswith(suf):
        # Already converted; a second conversion could corrupt the markup.
        return txt

    txt = fix_markdown_indent(txt)
    # txt = fix_code_segment_indent(txt)
    if not is_equation(txt):
        # No equations: plain Markdown conversion with code highlighting.
        plain_html = markdown.markdown(
            txt,
            extensions=['sane_lists', 'tables', 'pymdownx.superfences', 'pymdownx.highlight'],
            extension_configs=code_highlight_configs,
        )
        return pre + plain_html + suf

    # Equations present: convert everything to HTML with the math extension enabled.
    find_equation_pattern = r'<script type="math/tex(?:.*?)>(.*?)</script>'
    split = markdown.markdown(text='---')
    html_with_math = markdown.markdown(
        text=txt,
        extensions=['sane_lists', 'tables', 'mdx_math', 'pymdownx.superfences', 'pymdownx.highlight'],
        extension_configs={**markdown_extension_configs, **code_highlight_configs},
    )
    html_with_math = markdown_bug_hunt(html_with_math)
    # 1. Easy-to-copy TeX (math is not rendered).
    copyable_view = re.sub(find_equation_pattern, replace_math_no_render, html_with_math, flags=re.DOTALL)
    # 2. Rendered equations.
    rendered_view = re.sub(find_equation_pattern, replace_math_render, html_with_math, flags=re.DOTALL)
    # Concatenate both views.
    return pre + copyable_view + f'{split}' + rendered_view + suf


def close_up_code_segment_during_stream(gpt_reply):
    """
    While a reply is still streaming inside a code block (the opening ```
    has been emitted but the closing ``` has not), append the closing ```.

    Args:
        gpt_reply (str): The reply string returned by the GPT model so far.

    Returns:
        str: The reply with a closing fence appended on a new line when the
        number of ``` markers is odd; the reply unchanged when it has no
        marker, already ends with a marker, or has an even number of markers.
    """
    fence = '```'
    if fence not in gpt_reply or gpt_reply.endswith(fence):
        return gpt_reply

    # With the two cases above excluded, an odd number of markers means the
    # output is currently in the middle of a code block.
    inside_open_block = gpt_reply.count(fence) % 2 == 1
    return (gpt_reply + '\n```') if inside_open_block else gpt_reply


def format_io(self, y):
    """
    Render the last (input, output) pair of a conversation as HTML.

    The input part of the last item of ``y`` is split into paragraphs and
    run through Markdown; the output part has any half-emitted code fence
    closed and its Markdown and math converted to HTML. ``y`` is modified
    in place.

    Args:
        y: List of ``(input, output)`` pairs; either element may be None.

    Returns:
        ``[]`` when ``y`` is None or an empty list, otherwise ``y`` with
        its last item replaced by the rendered pair.
    """
    if y is None or y == []:
        return []

    question, answer = y[-1]
    # User input is free-form, so pre-process it first.
    if question is not None:
        question = text_divide_paragraph(question)
    # When code output is only half emitted, try to append the closing ```.
    if answer is not None:
        answer = close_up_code_segment_during_stream(answer)

    if question is None:
        rendered_question = None
    else:
        rendered_question = markdown.markdown(
            question,
            extensions=['pymdownx.superfences', 'tables', 'pymdownx.highlight'],
            extension_configs=code_highlight_configs,
        )
    rendered_answer = None if answer is None else markdown_convertion(answer)

    y[-1] = (rendered_question, rendered_answer)
    return y