#!/usr/bin/env python3
"""Manage the project database."""

# %% IMPORTS

import argparse
import logging
import re
import sys
import typing as T

import lib

# %% LOGGING

logging.basicConfig(
    level=logging.DEBUG,
    format="[%(asctime)s][%(levelname)s] %(message)s",
)

# %% PARSING

PARSER = argparse.ArgumentParser(description=__doc__)
PARSER.add_argument("files", type=argparse.FileType("r"), nargs="+")
PARSER.add_argument("--database", type=str, default=lib.DATABASE_PATH)
PARSER.add_argument("--collection", type=str, default=lib.DATABASE_COLLECTION)

# %% FUNCTIONS

def segment_text(text: str, pattern: str) -> T.Iterator[tuple[str, str]]:
    """Segment the text into (title, content) pairs by pattern."""
    splits = re.split(pattern, text, flags=re.MULTILINE)
    pairs = zip(splits[1::2], splits[2::2])
    return pairs
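
# Illustration only (not part of the script): with a single capture group,
# re.split returns [preamble, title1, body1, title2, body2, ...], so zipping
# the odd-indexed and even-indexed slices pairs each heading with the text
# that follows it. For example, segment_text("# Intro\nHello", r"^# (.+)")
# yields the single pair ("Intro", "\nHello").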

def import_file(
    file: T.TextIO,
    collection: lib.Collection,
    encoding_function: T.Callable,
    max_output_tokens: int = lib.ENCODING_OUTPUT_LIMIT,
) -> tuple[int, int]:
    """Import a markdown file into a database collection."""
    n_chars = 0
    n_tokens = 0
    text = file.read()
    filename = file.name
    # split the file by H1 headings, then each H1 section by H2 headings
    segments_h1 = segment_text(text=text, pattern=r"^# (.+)")
    for h1, h1_text in segments_h1:
        logging.debug('\t- H1: "%s" (%d)', h1, len(h1_text))
        segments_h2 = segment_text(text=h1_text, pattern=r"^## (.+)")
        for h2, content in segments_h2:
            content_chars = len(content)
            content_tokens = len(encoding_function(content))
            logging.debug('\t\t- H2: "%s" (%d)', h2, content_chars)
            id_ = f"{filename} # {h1} ## {h2}"  # unique doc id
            document = f"# {h1}\n\n## {h2}\n\n{content.strip()}"
            metadata = {"filename": filename, "h1": h1, "h2": h2}
            # guard against sections that exceed the encoding output limit
            assert (
                content_tokens < max_output_tokens
            ), f"Content is too long ({content_tokens}): #{h1} ##{h2}"
            collection.add(ids=id_, documents=document, metadatas=metadata)
            n_tokens += content_tokens
            n_chars += content_chars
    return n_chars, n_tokens
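
# Illustration only: for a file "docs/guide.md" containing "# Guide" with the
# sections "## Install" and "## Usage", import_file stores two documents with
# the ids "docs/guide.md # Guide ## Install" and "docs/guide.md # Guide ## Usage".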

def main(args: list[str] | None = None) -> int:
    """Main function of the script."""
    # parsing
    opts = PARSER.parse_args(args)
    # database
    database_path = opts.database
    logging.info("Database path: %s", database_path)
    client = lib.get_database_client(path=database_path)
    logging.info("- Resetting database client: %s", client.reset())
    # encoding
    encoding_function = lib.get_encoding_function()
    logging.info("Encoding function: %s", encoding_function)
    # embedding
    embedding_function = lib.get_embedding_function()
    logging.info("Embedding function: %s", embedding_function)
    # collection
    database_collection = opts.collection
    logging.info("Database collection: %s", database_collection)
    collection = client.create_collection(
        name=database_collection, embedding_function=embedding_function
    )
    # files
    for i, file in enumerate(opts.files):
        logging.info("Importing file %d: %s", i, file.name)
        n_chars, n_tokens = import_file(
            file=file, collection=collection, encoding_function=encoding_function
        )
        logging.info(
            "- Docs imported from file %d: %d chars | %d tokens", i, n_chars, n_tokens
        )
    # return
    return 0

# %% ENTRYPOINTS

if __name__ == "__main__":
    sys.exit(main())
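
# Example invocation (script name and paths are illustrative; the real
# defaults for --database and --collection come from lib):
#   python manage.py docs/*.md --database ./database --collection documents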