AlexanderKazakov
commited on
Commit
•
eeafaaa
1
Parent(s):
d7fdb42
small improvement for chunking; openai embeddings
Browse files- README initial.md +1 -1
- gradio_app/app.py +8 -6
- gradio_app/backend/embedders.py +43 -0
- gradio_app/backend/query_llm.py +7 -9
- gradio_app/backend/semantic_search.py +2 -2
- gradio_app/templates/context_html_template.j2 +5 -65
- prep_scripts/lancedb_setup.py +12 -26
- prep_scripts/markdown_to_text.py +17 -14
- settings.py +10 -6
README initial.md
CHANGED
@@ -13,7 +13,7 @@ Deliberately stripped down to leave some room for experimenting
|
|
13 |
- TODOs:
|
14 |
- Experiment with chunking, see how it affects the results. When deciding how to chunk it helps to think about what kind of chunks you'd like to see as context to your queries.
|
15 |
- Deliverables: Demonstrate how retrieved documents differ with different chunking strategies and how it affects the output.
|
16 |
-
- Try out different embedding models (
|
17 |
- Deliverables: Demonstrate how retrieved documents differ with different embedding models and how they affect the output. Provide an estimate of how the time to embed the chunks and DB ingestion time differs (happening in **prep_scrips/lancedb_setup.py**).
|
18 |
- Add a re-ranker (cross-encoder) to the pipeline. Start with sentence-transformers pages on cross-encoders [1](https://www.sbert.net/examples/applications/cross-encoder/README.html) [2](https://www.sbert.net/examples/applications/retrieve_rerank/README.html), then pick a [pretrained cross-encoder](https://www.sbert.net/docs/pretrained-models/ce-msmarco.html), e.g. **cross-encoder/ms-marco-MiniLM-L-12-v2**. Don't forget to increase the number of *retrieved* documents when using re-ranker. The number of documents used as context should stay the same.
|
19 |
- Deliverables: Demonstrate how retrieved documents differ after adding a re-ranker and how it affects the output. Provide an estimate of how latency changes.
|
|
|
13 |
- TODOs:
|
14 |
- Experiment with chunking, see how it affects the results. When deciding how to chunk it helps to think about what kind of chunks you'd like to see as context to your queries.
|
15 |
- Deliverables: Demonstrate how retrieved documents differ with different chunking strategies and how it affects the output.
|
16 |
+
- Try out different embedding models (EMBED_NAME). Good models to start with are **sentence-transformers/all-MiniLM-L6-v2** - lightweight, **thenlper/gte-large** - relatively heavy but more powerful.
|
17 |
- Deliverables: Demonstrate how retrieved documents differ with different embedding models and how they affect the output. Provide an estimate of how the time to embed the chunks and DB ingestion time differs (happening in **prep_scrips/lancedb_setup.py**).
|
18 |
- Add a re-ranker (cross-encoder) to the pipeline. Start with sentence-transformers pages on cross-encoders [1](https://www.sbert.net/examples/applications/cross-encoder/README.html) [2](https://www.sbert.net/examples/applications/retrieve_rerank/README.html), then pick a [pretrained cross-encoder](https://www.sbert.net/docs/pretrained-models/ce-msmarco.html), e.g. **cross-encoder/ms-marco-MiniLM-L-12-v2**. Don't forget to increase the number of *retrieved* documents when using re-ranker. The number of documents used as context should stay the same.
|
19 |
- Deliverables: Demonstrate how retrieved documents differ after adding a re-ranker and how it affects the output. Provide an estimate of how latency changes.
|
gradio_app/app.py
CHANGED
@@ -9,6 +9,7 @@ import logging
|
|
9 |
from time import perf_counter
|
10 |
|
11 |
import gradio as gr
|
|
|
12 |
from jinja2 import Environment, FileSystemLoader
|
13 |
|
14 |
from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
|
@@ -31,10 +32,10 @@ context_html_template = env.get_template('context_html_template.j2')
|
|
31 |
# Examples
|
32 |
examples = [
|
33 |
'What is BERT?',
|
34 |
-
'Tell me about
|
|
|
35 |
'What is the capital of China?',
|
36 |
'Why is the sky blue?',
|
37 |
-
'Who won the mens world cup in 2014?',
|
38 |
]
|
39 |
|
40 |
|
@@ -58,7 +59,7 @@ def bot(history, api_kind):
|
|
58 |
# Retrieve documents relevant to query
|
59 |
document_start = perf_counter()
|
60 |
|
61 |
-
query_vec = embedder.
|
62 |
documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank).to_list()
|
63 |
thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
|
64 |
documents = [d for d in documents if d['_distance'] <= thresh_dist]
|
@@ -69,10 +70,11 @@ def bot(history, api_kind):
|
|
69 |
|
70 |
while len(documents) != 0:
|
71 |
context = context_template.render(documents=documents)
|
72 |
-
|
|
|
73 |
messages = construct_openai_messages(context, history)
|
74 |
-
num_tokens = num_tokens_from_messages(messages,
|
75 |
-
if num_tokens + 512 < context_lengths[
|
76 |
break
|
77 |
documents.pop()
|
78 |
else:
|
|
|
9 |
from time import perf_counter
|
10 |
|
11 |
import gradio as gr
|
12 |
+
import markdown
|
13 |
from jinja2 import Environment, FileSystemLoader
|
14 |
|
15 |
from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
|
|
|
32 |
# Examples
|
33 |
examples = [
|
34 |
'What is BERT?',
|
35 |
+
'Tell me about GPT',
|
36 |
+
'How to use accelerate in google colab?',
|
37 |
'What is the capital of China?',
|
38 |
'Why is the sky blue?',
|
|
|
39 |
]
|
40 |
|
41 |
|
|
|
59 |
# Retrieve documents relevant to query
|
60 |
document_start = perf_counter()
|
61 |
|
62 |
+
query_vec = embedder.embed(query)[0]
|
63 |
documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank).to_list()
|
64 |
thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
|
65 |
documents = [d for d in documents if d['_distance'] <= thresh_dist]
|
|
|
70 |
|
71 |
while len(documents) != 0:
|
72 |
context = context_template.render(documents=documents)
|
73 |
+
documents_html = [markdown.markdown(d) for d in documents]
|
74 |
+
context_html = context_html_template.render(documents=documents_html)
|
75 |
messages = construct_openai_messages(context, history)
|
76 |
+
num_tokens = num_tokens_from_messages(messages, LLM_NAME)
|
77 |
+
if num_tokens + 512 < context_lengths[LLM_NAME]:
|
78 |
break
|
79 |
documents.pop()
|
80 |
else:
|
gradio_app/backend/embedders.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import openai
|
3 |
+
from sentence_transformers import SentenceTransformer
|
4 |
+
from abc import ABC, abstractmethod
|
5 |
+
|
6 |
+
|
7 |
+
class Embedder(ABC):
|
8 |
+
@abstractmethod
|
9 |
+
def embed(self, texts):
|
10 |
+
pass
|
11 |
+
|
12 |
+
|
13 |
+
class HfEmbedder(Embedder):
|
14 |
+
def __init__(self, model_name):
|
15 |
+
self.model = SentenceTransformer(model_name)
|
16 |
+
self.model.eval()
|
17 |
+
|
18 |
+
@torch.no_grad()
|
19 |
+
def embed(self, texts):
|
20 |
+
encoded = self.model.encode(texts, normalize_embeddings=True)
|
21 |
+
return [list(vec) for vec in encoded]
|
22 |
+
|
23 |
+
|
24 |
+
class OpenAIEmbedder(Embedder):
|
25 |
+
def __init__(self, model_name):
|
26 |
+
self.model_name = model_name
|
27 |
+
|
28 |
+
def embed(self, texts):
|
29 |
+
responses = openai.Embedding.create(input=texts, engine=self.model_name)
|
30 |
+
return [response['embedding'] for response in responses['data']]
|
31 |
+
|
32 |
+
|
33 |
+
class EmbedderFactory:
|
34 |
+
@staticmethod
|
35 |
+
def get_embedder(type):
|
36 |
+
if type == "sentence-transformers/all-MiniLM-L6-v2":
|
37 |
+
return HfEmbedder(type)
|
38 |
+
elif type == "text-embedding-ada-002":
|
39 |
+
return OpenAIEmbedder(type)
|
40 |
+
else:
|
41 |
+
raise ValueError(f"Unsupported embedder type: {type}")
|
42 |
+
|
43 |
+
|
gradio_app/backend/query_llm.py
CHANGED
@@ -2,19 +2,17 @@ import gradio as gr
|
|
2 |
|
3 |
from typing import Any, Dict, Generator, List
|
4 |
|
5 |
-
from huggingface_hub import InferenceClient
|
6 |
-
from transformers import AutoTokenizer
|
7 |
from jinja2 import Environment, FileSystemLoader
|
8 |
|
9 |
from settings import *
|
10 |
from gradio_app.backend.ChatGptInteractor import *
|
11 |
|
12 |
|
13 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
hf_client = InferenceClient(HF_LLM_NAME, token=HF_TOKEN)
|
18 |
|
19 |
|
20 |
def format_prompt(message: str, api_kind: str):
|
@@ -125,7 +123,7 @@ def construct_openai_messages(context, history):
|
|
125 |
|
126 |
|
127 |
def generate_openai(messages):
|
128 |
-
cgi = ChatGptInteractor(model_name=
|
129 |
for part in cgi.chat_completion(messages, max_tokens=512, temperature=0, stream=True):
|
130 |
yield cgi.get_stream_text(part)
|
131 |
|
@@ -162,7 +160,7 @@ def _generate_openai(prompt: str, history: str, temperature: float = 0.9, max_ne
|
|
162 |
|
163 |
try:
|
164 |
stream = openai.ChatCompletion.create(
|
165 |
-
model=
|
166 |
messages=formatted_prompt,
|
167 |
**generate_kwargs,
|
168 |
stream=True
|
|
|
2 |
|
3 |
from typing import Any, Dict, Generator, List
|
4 |
|
5 |
+
# from huggingface_hub import InferenceClient
|
6 |
+
# from transformers import AutoTokenizer
|
7 |
from jinja2 import Environment, FileSystemLoader
|
8 |
|
9 |
from settings import *
|
10 |
from gradio_app.backend.ChatGptInteractor import *
|
11 |
|
12 |
|
13 |
+
# tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
|
14 |
+
# HF_TOKEN = None
|
15 |
+
# hf_client = InferenceClient(LLM_NAME, token=HF_TOKEN)
|
|
|
|
|
16 |
|
17 |
|
18 |
def format_prompt(message: str, api_kind: str):
|
|
|
123 |
|
124 |
|
125 |
def generate_openai(messages):
|
126 |
+
cgi = ChatGptInteractor(model_name=LLM_NAME)
|
127 |
for part in cgi.chat_completion(messages, max_tokens=512, temperature=0, stream=True):
|
128 |
yield cgi.get_stream_text(part)
|
129 |
|
|
|
160 |
|
161 |
try:
|
162 |
stream = openai.ChatCompletion.create(
|
163 |
+
model=LLM_NAME,
|
164 |
messages=formatted_prompt,
|
165 |
**generate_kwargs,
|
166 |
stream=True
|
gradio_app/backend/semantic_search.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1 |
import logging
|
2 |
import lancedb
|
3 |
-
from sentence_transformers import SentenceTransformer
|
4 |
|
|
|
5 |
from settings import *
|
6 |
|
7 |
|
8 |
# Setting up the logging
|
9 |
logging.basicConfig(level=logging.INFO)
|
10 |
logger = logging.getLogger(__name__)
|
11 |
-
embedder =
|
12 |
|
13 |
db = lancedb.connect(LANCEDB_DIRECTORY)
|
14 |
table = db.open_table(LANCEDB_TABLE_NAME)
|
|
|
1 |
import logging
|
2 |
import lancedb
|
|
|
3 |
|
4 |
+
from gradio_app.backend.embedders import EmbedderFactory
|
5 |
from settings import *
|
6 |
|
7 |
|
8 |
# Setting up the logging
|
9 |
logging.basicConfig(level=logging.INFO)
|
10 |
logger = logging.getLogger(__name__)
|
11 |
+
embedder = EmbedderFactory.get_embedder(EMBED_NAME)
|
12 |
|
13 |
db = lancedb.connect(LANCEDB_DIRECTORY)
|
14 |
table = db.open_table(LANCEDB_TABLE_NAME)
|
gradio_app/templates/context_html_template.j2
CHANGED
@@ -11,85 +11,25 @@
|
|
11 |
font-family: "Source Sans Pro";
|
12 |
}
|
13 |
|
14 |
-
.instructions > * {
|
15 |
-
color: #111 !important;
|
16 |
-
}
|
17 |
-
|
18 |
-
details.doc-box * {
|
19 |
-
color: #111 !important;
|
20 |
-
}
|
21 |
-
|
22 |
-
.dark {
|
23 |
-
background: #111;
|
24 |
-
color: white;
|
25 |
-
}
|
26 |
-
|
27 |
.doc-box {
|
28 |
padding: 10px;
|
29 |
margin-top: 10px;
|
30 |
-
background-color: #
|
31 |
border-radius: 6px;
|
32 |
color: #111 !important;
|
33 |
-
max-width: 700px;
|
34 |
-
box-shadow: rgba(0, 0, 0, 0.2) 0px 1px 2px 0px;
|
35 |
-
}
|
36 |
-
|
37 |
-
.doc-full {
|
38 |
-
margin: 10px 14px;
|
39 |
-
line-height: 1.6rem;
|
40 |
-
}
|
41 |
-
|
42 |
-
.instructions {
|
43 |
-
color: #111 !important;
|
44 |
-
background: #b7bdfd;
|
45 |
-
display: block;
|
46 |
-
border-radius: 6px;
|
47 |
-
padding: 6px 10px;
|
48 |
-
line-height: 1.6rem;
|
49 |
-
max-width: 700px;
|
50 |
-
box-shadow: rgba(0, 0, 0, 0.2) 0px 1px 2px 0px;
|
51 |
-
}
|
52 |
-
|
53 |
-
.query {
|
54 |
-
color: #111 !important;
|
55 |
-
background: #ffbcbc;
|
56 |
-
display: block;
|
57 |
-
border-radius: 6px;
|
58 |
-
padding: 6px 10px;
|
59 |
-
line-height: 1.6rem;
|
60 |
-
max-width: 700px;
|
61 |
box-shadow: rgba(0, 0, 0, 0.2) 0px 1px 2px 0px;
|
62 |
}
|
63 |
</style>
|
64 |
</head>
|
65 |
<body>
|
66 |
-
|
67 |
<h2>Context:</h2>
|
68 |
{% for doc in documents %}
|
69 |
-
<
|
70 |
-
|
71 |
-
|
72 |
-
</summary>
|
73 |
-
<div class="doc-full">{{ doc }}</div>
|
74 |
-
</details>
|
75 |
{% endfor %}
|
76 |
-
</div>
|
77 |
|
78 |
-
<script>
|
79 |
-
document.addEventListener("DOMContentLoaded", function() {
|
80 |
-
const detailsElements = document.querySelectorAll('.doc-box');
|
81 |
|
82 |
-
detailsElements.forEach(detail => {
|
83 |
-
detail.addEventListener('toggle', function() {
|
84 |
-
const docShort = this.querySelector('.doc-short');
|
85 |
-
if (this.open) {
|
86 |
-
docShort.style.display = 'none';
|
87 |
-
} else {
|
88 |
-
docShort.style.display = 'inline';
|
89 |
-
}
|
90 |
-
});
|
91 |
-
});
|
92 |
-
});
|
93 |
-
</script>
|
94 |
</body>
|
95 |
</html>
|
|
|
11 |
font-family: "Source Sans Pro";
|
12 |
}
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
.doc-box {
|
15 |
padding: 10px;
|
16 |
margin-top: 10px;
|
17 |
+
background-color: #374151;
|
18 |
border-radius: 6px;
|
19 |
color: #111 !important;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
box-shadow: rgba(0, 0, 0, 0.2) 0px 1px 2px 0px;
|
21 |
}
|
22 |
</style>
|
23 |
</head>
|
24 |
<body>
|
25 |
+
|
26 |
<h2>Context:</h2>
|
27 |
{% for doc in documents %}
|
28 |
+
<div class="doc-box">
|
29 |
+
{{ doc }}
|
30 |
+
</div>
|
|
|
|
|
|
|
31 |
{% endfor %}
|
|
|
32 |
|
|
|
|
|
|
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
</body>
|
35 |
</html>
|
prep_scripts/lancedb_setup.py
CHANGED
@@ -1,36 +1,29 @@
|
|
1 |
import shutil
|
2 |
-
import traceback
|
3 |
|
4 |
import lancedb
|
5 |
-
import
|
6 |
import pyarrow as pa
|
7 |
import pandas as pd
|
8 |
from pathlib import Path
|
9 |
import tqdm
|
10 |
import numpy as np
|
11 |
|
12 |
-
from
|
13 |
-
|
14 |
from markdown_to_text import *
|
15 |
from settings import *
|
16 |
|
17 |
|
|
|
|
|
|
|
|
|
|
|
18 |
shutil.rmtree(LANCEDB_DIRECTORY, ignore_errors=True)
|
19 |
db = lancedb.connect(LANCEDB_DIRECTORY)
|
20 |
batch_size = 32
|
21 |
|
22 |
-
model = SentenceTransformer(EMB_MODEL_NAME)
|
23 |
-
model.eval()
|
24 |
-
|
25 |
-
if torch.backends.mps.is_available():
|
26 |
-
device = "mps"
|
27 |
-
elif torch.cuda.is_available():
|
28 |
-
device = "cuda"
|
29 |
-
else:
|
30 |
-
device = "cpu"
|
31 |
-
|
32 |
schema = pa.schema([
|
33 |
-
pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(), emb_sizes[
|
34 |
pa.field(TEXT_COLUMN_NAME, pa.string()),
|
35 |
pa.field(DOCUMENT_PATH_COLUMN_NAME, pa.string()),
|
36 |
])
|
@@ -49,17 +42,18 @@ for file in files:
|
|
49 |
print(f'Skipped {file_ext} extension: {file}')
|
50 |
continue
|
51 |
|
52 |
-
doc_header = ' / '.join(split_path(file_path)) + ':\n\n'
|
53 |
with open(file, encoding='utf-8') as f:
|
54 |
f = f.read()
|
55 |
f = remove_comments(f)
|
56 |
f = split_markdown(f)
|
57 |
-
chunks.extend((
|
58 |
|
59 |
from matplotlib import pyplot as plt
|
60 |
plt.hist([len(c) for c, d in chunks], bins=100)
|
61 |
plt.show()
|
62 |
|
|
|
|
|
63 |
for i in tqdm.tqdm(range(0, int(np.ceil(len(chunks) / batch_size)))):
|
64 |
texts, doc_paths = [], []
|
65 |
for text, doc_path in chunks[i * batch_size:(i + 1) * batch_size]:
|
@@ -67,9 +61,7 @@ for i in tqdm.tqdm(range(0, int(np.ceil(len(chunks) / batch_size)))):
|
|
67 |
texts.append(text)
|
68 |
doc_paths.append(doc_path)
|
69 |
|
70 |
-
encoded =
|
71 |
-
encoded = [list(vec) for vec in encoded]
|
72 |
-
|
73 |
df = pd.DataFrame({
|
74 |
VECTOR_COLUMN_NAME: encoded,
|
75 |
TEXT_COLUMN_NAME: texts,
|
@@ -79,10 +71,4 @@ for i in tqdm.tqdm(range(0, int(np.ceil(len(chunks) / batch_size)))):
|
|
79 |
tbl.add(df)
|
80 |
|
81 |
|
82 |
-
# '''
|
83 |
-
# create ivf-pd index https://lancedb.github.io/lancedb/ann_indexes/
|
84 |
-
# with the size of the transformer docs, index is not really needed
|
85 |
-
# but we'll do it for demonstration purposes
|
86 |
-
# '''
|
87 |
-
# tbl.create_index(num_partitions=256, num_sub_vectors=96, vector_column_name=VECTOR_COLUMN_NAME)
|
88 |
|
|
|
1 |
import shutil
|
|
|
2 |
|
3 |
import lancedb
|
4 |
+
import openai
|
5 |
import pyarrow as pa
|
6 |
import pandas as pd
|
7 |
from pathlib import Path
|
8 |
import tqdm
|
9 |
import numpy as np
|
10 |
|
11 |
+
from gradio_app.backend.embedders import EmbedderFactory
|
|
|
12 |
from markdown_to_text import *
|
13 |
from settings import *
|
14 |
|
15 |
|
16 |
+
with open('data/openaikey.txt') as f:
|
17 |
+
OPENAI_KEY = f.read().strip()
|
18 |
+
openai.api_key = OPENAI_KEY
|
19 |
+
|
20 |
+
|
21 |
shutil.rmtree(LANCEDB_DIRECTORY, ignore_errors=True)
|
22 |
db = lancedb.connect(LANCEDB_DIRECTORY)
|
23 |
batch_size = 32
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
schema = pa.schema([
|
26 |
+
pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(), emb_sizes[EMBED_NAME])),
|
27 |
pa.field(TEXT_COLUMN_NAME, pa.string()),
|
28 |
pa.field(DOCUMENT_PATH_COLUMN_NAME, pa.string()),
|
29 |
])
|
|
|
42 |
print(f'Skipped {file_ext} extension: {file}')
|
43 |
continue
|
44 |
|
|
|
45 |
with open(file, encoding='utf-8') as f:
|
46 |
f = f.read()
|
47 |
f = remove_comments(f)
|
48 |
f = split_markdown(f)
|
49 |
+
chunks.extend((chunk, os.path.abspath(file)) for chunk in f)
|
50 |
|
51 |
from matplotlib import pyplot as plt
|
52 |
plt.hist([len(c) for c, d in chunks], bins=100)
|
53 |
plt.show()
|
54 |
|
55 |
+
embedder = EmbedderFactory.get_embedder(EMBED_NAME)
|
56 |
+
|
57 |
for i in tqdm.tqdm(range(0, int(np.ceil(len(chunks) / batch_size)))):
|
58 |
texts, doc_paths = [], []
|
59 |
for text, doc_path in chunks[i * batch_size:(i + 1) * batch_size]:
|
|
|
61 |
texts.append(text)
|
62 |
doc_paths.append(doc_path)
|
63 |
|
64 |
+
encoded = embedder.embed(texts)
|
|
|
|
|
65 |
df = pd.DataFrame({
|
66 |
VECTOR_COLUMN_NAME: encoded,
|
67 |
TEXT_COLUMN_NAME: texts,
|
|
|
71 |
tbl.add(df)
|
72 |
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
prep_scripts/markdown_to_text.py
CHANGED
@@ -21,33 +21,26 @@ def remove_comments(md):
|
|
21 |
return re.sub(r'<!--((.|\n)*)-->', '', md)
|
22 |
|
23 |
|
24 |
-
header_pattern = re.compile(r'\n\s*\n(#{1,3})\s
|
25 |
|
26 |
|
27 |
def split_content(content):
|
|
|
28 |
_parts = content.split('\n\n')
|
29 |
parts = []
|
30 |
for p in _parts:
|
31 |
-
if len(p) <
|
32 |
parts.append(p)
|
33 |
else:
|
34 |
parts.extend(p.split('\n'))
|
35 |
|
36 |
res = ['']
|
37 |
for p in parts:
|
38 |
-
if len(res[-1]) + len(p) <
|
39 |
res[-1] += p + '\n\n'
|
40 |
else:
|
41 |
res.append(p + '\n\n')
|
42 |
|
43 |
-
if (
|
44 |
-
len(res) >= 2 and
|
45 |
-
len(res[-1]) < TEXT_CHUNK_SIZE / 4 and
|
46 |
-
len(res[-2]) < TEXT_CHUNK_SIZE
|
47 |
-
):
|
48 |
-
res[-2] += res[-1]
|
49 |
-
res.pop()
|
50 |
-
|
51 |
return res
|
52 |
|
53 |
|
@@ -65,20 +58,30 @@ def split_markdown(md):
|
|
65 |
chunk = ''
|
66 |
for i in sorted(name_hierarchy):
|
67 |
if len(name_hierarchy[i]) != 0:
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
chunk += content
|
71 |
chunk = chunk.strip()
|
72 |
res.append(chunk)
|
73 |
|
74 |
-
|
|
|
75 |
headers = list(header_pattern.finditer(md))
|
|
|
|
|
|
|
76 |
name_hierarchy = {i: '' for i in (1, 2, 3)}
|
77 |
res = []
|
78 |
for i in range(len(headers)):
|
79 |
header = headers[i]
|
80 |
level = len(header.group(1))
|
81 |
-
name = header.group().strip()
|
82 |
name_hierarchy[level] = name
|
83 |
if i == 0 and header.start() != 0:
|
84 |
construct_chunks(md[:header.start()])
|
|
|
21 |
return re.sub(r'<!--((.|\n)*)-->', '', md)
|
22 |
|
23 |
|
24 |
+
header_pattern = re.compile(r'\n\s*\n(#{1,3})\s(.*)\n\s*\n')
|
25 |
|
26 |
|
27 |
def split_content(content):
|
28 |
+
text_chunk_size = context_lengths[EMBED_NAME] - 32
|
29 |
_parts = content.split('\n\n')
|
30 |
parts = []
|
31 |
for p in _parts:
|
32 |
+
if len(p) < text_chunk_size:
|
33 |
parts.append(p)
|
34 |
else:
|
35 |
parts.extend(p.split('\n'))
|
36 |
|
37 |
res = ['']
|
38 |
for p in parts:
|
39 |
+
if len(res[-1]) + len(p) < text_chunk_size:
|
40 |
res[-1] += p + '\n\n'
|
41 |
else:
|
42 |
res.append(p + '\n\n')
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
return res
|
45 |
|
46 |
|
|
|
58 |
chunk = ''
|
59 |
for i in sorted(name_hierarchy):
|
60 |
if len(name_hierarchy[i]) != 0:
|
61 |
+
j = i + 1
|
62 |
+
while j in name_hierarchy:
|
63 |
+
if name_hierarchy[j].find(name_hierarchy[i]) != -1:
|
64 |
+
break
|
65 |
+
j += 1
|
66 |
+
else:
|
67 |
+
chunk += f'{"#" * (i + 1)}{name_hierarchy[i]}\n\n'
|
68 |
|
69 |
chunk += content
|
70 |
chunk = chunk.strip()
|
71 |
res.append(chunk)
|
72 |
|
73 |
+
# to find a header at the top of a file
|
74 |
+
md = f'\n\n{md}'
|
75 |
headers = list(header_pattern.finditer(md))
|
76 |
+
# only first header can be first-level
|
77 |
+
headers = [h for i, h in enumerate(headers) if i == 0 or len(h.group(1)) > 1]
|
78 |
+
|
79 |
name_hierarchy = {i: '' for i in (1, 2, 3)}
|
80 |
res = []
|
81 |
for i in range(len(headers)):
|
82 |
header = headers[i]
|
83 |
level = len(header.group(1))
|
84 |
+
name = header.group(2).strip()
|
85 |
name_hierarchy[level] = name
|
86 |
if i == 0 and header.start() != 0:
|
87 |
construct_chunks(md[:header.start()])
|
settings.py
CHANGED
@@ -1,22 +1,26 @@
|
|
1 |
MARKDOWN_SOURCE_DIR = "data/transformers/docs/source/en/"
|
2 |
-
EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
3 |
LANCEDB_DIRECTORY = "data/lancedb"
|
4 |
LANCEDB_TABLE_NAME = "table"
|
5 |
VECTOR_COLUMN_NAME = "embedding"
|
6 |
TEXT_COLUMN_NAME = "text"
|
7 |
DOCUMENT_PATH_COLUMN_NAME = "document_path"
|
8 |
-
HF_LLM_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
|
9 |
-
OPENAI_LLM_NAME = "gpt-3.5-turbo"
|
10 |
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
13 |
|
14 |
emb_sizes = {
|
15 |
"sentence-transformers/all-MiniLM-L6-v2": 384,
|
16 |
-
"thenlper/gte-large":
|
|
|
17 |
}
|
18 |
|
19 |
context_lengths = {
|
20 |
"mistralai/Mistral-7B-Instruct-v0.1": 4096,
|
21 |
"gpt-3.5-turbo": 4096,
|
|
|
|
|
|
|
22 |
}
|
|
|
1 |
MARKDOWN_SOURCE_DIR = "data/transformers/docs/source/en/"
|
|
|
2 |
LANCEDB_DIRECTORY = "data/lancedb"
|
3 |
LANCEDB_TABLE_NAME = "table"
|
4 |
VECTOR_COLUMN_NAME = "embedding"
|
5 |
TEXT_COLUMN_NAME = "text"
|
6 |
DOCUMENT_PATH_COLUMN_NAME = "document_path"
|
|
|
|
|
7 |
|
8 |
+
# LLM_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
|
9 |
+
LLM_NAME = "gpt-3.5-turbo"
|
10 |
+
# EMBED_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
11 |
+
EMBED_NAME = "text-embedding-ada-002"
|
12 |
+
|
13 |
|
14 |
emb_sizes = {
|
15 |
"sentence-transformers/all-MiniLM-L6-v2": 384,
|
16 |
+
"thenlper/gte-large": 1024,
|
17 |
+
"text-embedding-ada-002": 1536,
|
18 |
}
|
19 |
|
20 |
context_lengths = {
|
21 |
"mistralai/Mistral-7B-Instruct-v0.1": 4096,
|
22 |
"gpt-3.5-turbo": 4096,
|
23 |
+
"sentence-transformers/all-MiniLM-L6-v2": 128,
|
24 |
+
"thenlper/gte-large": 512,
|
25 |
+
"text-embedding-ada-002": 8191,
|
26 |
}
|