import os
import re
import sys
import threading
import torch
from sentence_transformers import SentenceTransformer, util
from typing import Dict, List, Tuple, Set, LiteralString


class EmbeddingContext:
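    """Load the configured embedding backend (SBERT or OpenAI) and expose a
    uniform ``encode()`` interface that returns torch tensors."""
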
    # Maximum number of tokens per text sent to the embedding model.
    TOKEN_LEN_MAX_FOR_EMBEDDING = 512

    lock = None
    model = None
    openai_client = None
    model_name = ''
    config_type = ''
    embedding_shape = None
    embedding_dtype = None
    embedding_device = None

    # Note: defined at class level, so this dict is shared across instances.
    data = {}

    def __init__(self):
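        # Import the project settings; if the package is not on sys.path
        # (e.g. when run as a standalone script), add the repository root
        # two levels up and retry.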
        try:
            from config import settings
        except ImportError:
            sys.path.append(os.path.abspath(
                os.path.join(os.path.dirname(__file__), '../..')))
            from config import settings

        self.lock = threading.Lock()
        config_type = settings.embedding_api
        model_name = settings.embedding_model

        if config_type == 'sbert':
            self.model = SentenceTransformer(model_name, use_auth_token=False)
            self.model.max_seq_length = self.TOKEN_LEN_MAX_FOR_EMBEDDING
            print("Max Sequence Length:", self.model.max_seq_length)

            self.encode = self.encode_sbert
            if torch.cuda.is_available():
                self.model = self.model.to('cuda')

        elif config_type == 'openai':
            from openai import OpenAI
            self.openai_client = OpenAI(
                api_key=settings.OPENAI_API_KEY,
            )
            self.encode = self.encode_openai

        self.model_name = model_name
        self.config_type = config_type

        # Encode a dummy text once to record the embedding shape, dtype and device.
        tmp = self.encode(['tmp'])
        self.embedding_shape = tmp.shape[1:]
        self.embedding_dtype = tmp.dtype
        self.embedding_device = tmp.device

    def encode(self, texts_to_embed):
        # Placeholder; __init__ rebinds this to encode_sbert or encode_openai.
        pass

    def encode_sbert(self, texts_to_embed):
        return self.model.encode(texts_to_embed,
                                 show_progress_bar=True,
                                 convert_to_tensor=True,
                                 normalize_embeddings=True)

    def encode_openai(self, texts_to_embed):
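        """Embed ``texts_to_embed`` via the OpenAI embeddings API.

        The input is sent in chunks of roughly 500k tokens, pausing between
        chunks, to stay within rate limits.
        """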
        import math
        import time

        # Count tokens across all inputs to decide how many request chunks are needed.
        tokens_count = 0
        for text in texts_to_embed:
            tokens_count += len(self.get_tokens(text))

        # Roughly 500k tokens per request chunk; always at least one chunk.
        chunks_num = max(1, math.ceil(tokens_count / 500000))
        chunk_size = math.ceil(len(texts_to_embed) / chunks_num)

        embeddings = []
        for i in range(chunks_num):
            start = i * chunk_size
            end = start + chunk_size
            chunk = texts_to_embed[start:end]

            embeddings_tmp = self.openai_client.embeddings.create(
                model=self.model_name,
                input=chunk,
            ).data

            if embeddings_tmp is None:
                break

            embeddings.extend(embeddings_tmp)

            # Pause between chunks to avoid hitting rate limits.
            if i < chunks_num - 1:
                time.sleep(60)

        return torch.stack([torch.tensor(embedding.embedding, dtype=torch.float32)
                            for embedding in embeddings])

    def get_tokens(self, text):
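        """Tokenize ``text`` with the loaded SBERT tokenizer when available,
        otherwise fall back to a rough regex-based split."""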
        if self.model:
            return self.model.tokenizer.tokenize(text)

        # Fallback: rough split on word boundaries and non-word characters.
        tokens = []
        for token in re.split(r'(\W|\b)', text):
            if token.strip():
                tokens.append(token)

        return tokens


class SplitDocs:
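    """Utilities to split Markdown/reStructuredText documentation into
    sections, either per topic or into chunks sized for the embedding model."""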

    def split_in_topics(self,
                        filedir: LiteralString = None,
                        *,
                        pattern_filename=r'(?<!navigation)\.(md|rst)',
                        pattern_content_sub=r'---\nhide:[\s\S]+?---\s*',
                        patterns_titles=(
                            r'^# (.+)', r'^## (.+)', r'^### (.+)'),
                        ) -> List[Tuple[str, List[str], str]]:
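        """Walk ``filedir`` and yield ``(rel_path, titles, content)`` tuples,
        one per titled section found in the matching documentation files."""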
        def matches_pattern(filename):
            return re.search(pattern_filename, filename) is not None

        def split_patterns_recursive(patterns, text, index=-1):
            # ``re.split`` with one capturing group alternates between text
            # between matches (even indices) and captured titles (odd indices).
            sections = re.split(patterns[0], text, flags=re.MULTILINE)
            for i, section in enumerate(sections):
                if not section.strip():
                    continue
                is_match = bool(i & 1)
                if is_match:
                    yield (index, section)
                elif len(patterns) > 1:
                    for j, section_j in split_patterns_recursive(patterns[1:], section, index + 1):
                        yield (j, section_j)
                else:
                    yield (-1, section)

        for root, _, files in os.walk(filedir):
            for name in files:
                if not matches_pattern(name):
                    continue

                full_path = os.path.join(root, name)
                with open(full_path, 'r', encoding='utf-8') as file:
                    content = file.read()

                if pattern_content_sub:
                    content = re.sub(pattern_content_sub, '', content)

                rel_path = full_path.replace(filedir, '').replace('\\', '/')

                # Match fenced code blocks first so they stay whole and are
                # never split by the title patterns.
                patterns = (r'(```[\s\S]+?```)', *patterns_titles)

                last_titles = []
                last_titles_index = []
                content_accum = ''
                i = -1  # Guard in case the file yields no sections at all.
                for i, section in split_patterns_recursive(patterns, content):
                    if i < 0:
                        # Plain content (or a fenced code block): accumulate it
                        # under the current title stack.
                        content_accum += section
                        continue
                    # A title was found: emit the section accumulated so far.
                    if content_accum:
                        yield rel_path, last_titles, content_accum
                        content_accum = ''
                    if not last_titles_index or i > last_titles_index[-1]:
                        last_titles_index.append(i)
                        last_titles.append(section)
                        continue
                    # Title at the same or a higher level: drop deeper titles.
                    while len(last_titles_index) > 1 and i < last_titles_index[-1]:
                        last_titles_index.pop()
                        last_titles.pop()

                    last_titles_index[-1] = i
                    last_titles[-1] = section
                if content_accum or i != -1:
                    yield rel_path, last_titles, content_accum

    def reduce_text(_self, text):
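        """Remove leading newlines, HTML-like tags, ``:role:``-style markers
        and extra blank lines from ``text``."""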
        text = re.sub(r'^\n+', '', text)      # Leading newlines
        text = re.sub(r'<.*?>', '', text)     # HTML-like tags
        text = re.sub(r':\S*: ', '', text)    # ':role:'-style markers
        text = re.sub(r'\s*\n+', '\n', text)  # Collapse blank lines
        return text

    def embedding_header(_self, rel_path, titles):
        return f"{rel_path}\n# {' | '.join(titles)}\n\n"

    def split_for_embedding(self,
                            filedir: LiteralString = None,
                            *,
                            pattern_filename=r'(?<!navigation)\.(md|rst)',
                            pattern_content_sub=r'---\nhide:[\s\S]+?---\s*',
                            patterns_titles=(
                                r'^# (.+)', r'^## (.+)', r'^### (.+)'),
                            ):
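        """Split the documentation under ``filedir`` into texts that fit within
        the embedding model's maximum sequence length, each prefixed by a
        header naming the file and its title path."""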
        tokenizer = EMBEDDING_CTX.model.tokenizer
        max_tokens = EMBEDDING_CTX.model.max_seq_length
        texts = []

        for rel_path, titles, content in self.split_in_topics(
                filedir,
                pattern_filename=pattern_filename,
                pattern_content_sub=pattern_content_sub,
                patterns_titles=patterns_titles):
            header = self.embedding_header(rel_path, titles)
            tokens_pre_len = len(tokenizer.tokenize(header))
            tokens_so_far = tokens_pre_len
            text_so_far = header
            for part in self.reduce_text(content).splitlines():
                part += '\n'
                part_tokens_len = len(tokenizer.tokenize(part))
                # Start a new text (repeating the header) if adding this line
                # would exceed the model's token limit.
                if tokens_so_far + part_tokens_len > max_tokens:
                    texts.append(text_so_far)
                    text_so_far = header
                    tokens_so_far = tokens_pre_len
                text_so_far += part
                tokens_so_far += part_tokens_len

            # Only keep the text if something was added beyond the header.
            if tokens_so_far != tokens_pre_len:
                texts.append(text_so_far)

        return texts


EMBEDDING_CTX = EmbeddingContext()
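
# Minimal usage sketch (illustrative only; the docs path is hypothetical):
#
#   docs = SplitDocs()
#   texts = docs.split_for_embedding(filedir='/path/to/docs')
#   embeddings = EMBEDDING_CTX.encode(texts)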