File size: 2,976 Bytes
8a41f4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import re
from math import ceil
from typing import List

import tiktoken
from application.parser.schema.base import Document


def separate_header_and_body(text):
    """Split *text* into its first three lines (the header) and the remainder.

    Args:
        text: The string to split.

    Returns:
        A ``(header, body)`` tuple. ``header`` is the first three
        newline-terminated lines of *text* and ``body`` is everything after
        them. If *text* has fewer than three newline-terminated lines there is
        no header to extract: ``header`` is ``""`` and ``body`` is *text*
        unchanged.
    """
    header_pattern = r"^(.*?\n){3}"
    match = re.match(header_pattern, text)
    if match is None:
        # Fewer than three complete lines. Without this guard the call to
        # ``match.group`` below raised AttributeError on short inputs.
        return "", text
    header = match.group(0)
    body = text[len(header):]
    return header, body


def group_documents(documents: List[Document], min_tokens: int, max_tokens: int) -> List[Document]:
    """Merge consecutive small documents into larger groups.

    Documents are visited in order. An incoming document is appended to the
    group being built (joined with a single space) only when it is itself
    shorter than ``min_tokens`` and the group plus the document stays below
    ``max_tokens``; otherwise the group is emitted and the document starts a
    new one. Token counts use the ``cl100k_base`` tiktoken encoding.

    Args:
        documents: Documents to group, in order.
        min_tokens: A document is only merged into the current group when its
            own token count is strictly below this value.
        max_tokens: A merge only happens when the group's token count plus
            the document's token count is strictly below this value.

    Returns:
        A new list of ``Document`` objects. Each one is created from the first
        document of its group (same ``doc_id``, ``embedding`` and
        ``extra_info``); the text of merged documents is appended to it.
    """
    # Look the encoder up once instead of on every loop iteration.
    encoding = tiktoken.get_encoding("cl100k_base")

    docs = []
    current_group = None
    # Token count of current_group.text, refreshed only when that text changes
    # so the growing group is not re-encoded for every incoming document.
    group_len = 0

    for doc in documents:
        doc_len = len(encoding.encode(doc.text))

        if current_group is not None:
            if group_len + doc_len < max_tokens and doc_len < min_tokens:
                current_group.text += " " + doc.text
                # Re-encode the merged text rather than summing the counts:
                # tokenising the joined string may not equal the sum of parts.
                group_len = len(encoding.encode(current_group.text))
                continue
            docs.append(current_group)

        current_group = Document(text=doc.text, doc_id=doc.doc_id, embedding=doc.embedding,
                                 extra_info=doc.extra_info)
        group_len = doc_len

    if current_group is not None:
        docs.append(current_group)

    return docs


def split_documents(documents: List[Document], max_tokens: int) -> List[Document]:
    """Split documents longer than ``max_tokens`` into several smaller ones.

    Documents at or below the limit are passed through unchanged. A longer
    document has its first three lines treated as a header that is repeated at
    the start of every part; the rest of the text is cut into equally sized
    character slices. Token counts use the ``cl100k_base`` tiktoken encoding.

    Args:
        documents: Documents to process, in order.
        max_tokens: Token limit above which a document is split.

    Returns:
        A new list containing the untouched short documents and, for each long
        document, its parts. Parts get the id ``"<doc_id>-<index>"`` and share
        the original document's ``embedding`` and ``extra_info``.
    """
    # Look the encoder up once instead of on every loop iteration.
    encoding = tiktoken.get_encoding("cl100k_base")

    docs = []
    for doc in documents:
        token_length = len(encoding.encode(doc.text))
        if token_length <= max_tokens:
            docs.append(doc)
            continue

        # The header is the first three newline-terminated lines. A text with
        # fewer lines has no header; handle it here so one such document
        # cannot make the header extraction raise and abort the whole split.
        if doc.text.count("\n") >= 3:
            header, body = separate_header_and_body(doc.text)
        else:
            header, body = "", doc.text

        # A header that alone exceeds the limit cannot be repeated on every
        # part, so fall back to splitting the full text without a header.
        if len(encoding.encode(header)) > max_tokens:
            body = doc.text
            header = ""

        # NOTE(review): the number of parts comes from the token count of the
        # whole text, while the cut itself is by characters and the header is
        # prepended to each part, so parts are only approximately bounded by
        # max_tokens.
        num_body_parts = ceil(token_length / max_tokens)
        part_length = ceil(len(body) / num_body_parts)
        body_parts = [body[i:i + part_length] for i in range(0, len(body), part_length)]
        for i, body_part in enumerate(body_parts):
            new_doc = Document(text=header + body_part.strip(),
                               doc_id=f"{doc.doc_id}-{i}",
                               embedding=doc.embedding,
                               extra_info=doc.extra_info)
            docs.append(new_doc)
    return docs


def group_split(documents: List[Document], max_tokens: int = 2000, min_tokens: int = 150,
                token_check: bool = True) -> List[Document]:
    """Group small documents, then split oversized ones.

    Both steps are best-effort: if one raises, the error is reported on
    stdout and processing continues with the documents as they were before
    that step.

    Args:
        documents: Documents to process.
        max_tokens: Upper token bound passed to grouping and splitting.
        min_tokens: Token threshold below which a document counts as small
            during grouping.
        token_check: When false, skip both steps and return ``documents``
            unchanged.

    Returns:
        The processed documents, or the input list itself when
        ``token_check`` is false.
    """
    if not token_check:
        return documents

    print("Grouping small documents")
    try:
        documents = group_documents(documents=documents, min_tokens=min_tokens, max_tokens=max_tokens)
    except Exception as error:
        # Best-effort step: keep going, but show the cause instead of hiding it.
        print(f"Grouping failed, try running without token_check: {error!r}")

    print("Separating large documents")
    try:
        documents = split_documents(documents=documents, max_tokens=max_tokens)
    except Exception as error:
        # This message used to say "Grouping failed", which pointed at the
        # wrong step.
        print(f"Splitting failed, try running without token_check: {error!r}")

    return documents