resume / database.py
Médéric Hurier (Fmind)
First release
ed67987
raw
history blame
3.45 kB
#!/usr/bin/env python3
"""Manage the project database."""
# %% IMPORTS
import argparse
import logging
import re
import sys
import typing as T
import lib
# %% LOGGING
# root logger: verbose output with a timestamp and level prefix on each record
logging.basicConfig(
    format="[%(asctime)s][%(levelname)s] %(message)s",
    level=logging.DEBUG,
)
# %% PARSING
# command-line interface: one or more input files opened for reading,
# plus optional overrides for the database path and collection name
PARSER = argparse.ArgumentParser(description=__doc__)
PARSER.add_argument("files", nargs="+", type=argparse.FileType("r"))
PARSER.add_argument("--database", default=lib.DATABASE_PATH, type=str)
PARSER.add_argument("--collection", default=lib.DATABASE_COLLECTION, type=str)
# %% FUNCTIONS
def segment_text(text: str, pattern: str) -> T.Iterator[tuple[str, str]]:
    """Split a text into (title, content) pairs around a heading pattern.

    The pattern is applied in multiline mode and is expected to hold one
    capture group (the title), so that the split alternates between captured
    titles and the text that follows them. Whatever precedes the first match
    is discarded.
    """
    # re.split yields: [preamble, title, content, title, content, ...]
    _preamble, *chunks = re.split(pattern, text, flags=re.MULTILINE)
    # pulling twice from one shared iterator pairs each title with its content
    cursor = iter(chunks)
    return zip(cursor, cursor)
def import_file(
    file: T.TextIO,
    collection: lib.Collection,
    encoding_function: T.Callable,
    max_output_tokens: int = lib.ENCODING_OUTPUT_LIMIT,
) -> tuple[int, int]:
    """Import a markdown file to a database collection.

    The text is segmented by H1 headings, then each H1 section is segmented
    by H2 headings. Every H2 section is added to the collection as one
    document made of both headings and the stripped section content.
    Text that is not located under an H2 heading is not imported.

    Args:
        file: open text file; it is read entirely and its name is used
            in the document ids and metadata.
        collection: collection receiving the documents (one add per section).
        encoding_function: callable applied to the section content; the
            length of its result is counted as the number of tokens.
        max_output_tokens: exclusive upper bound on the tokens of a section.

    Returns:
        The total number of characters and tokens of the imported sections.

    Raises:
        AssertionError: if a section has max_output_tokens tokens or more.
    """
    n_chars = 0
    n_tokens = 0
    text = file.read()
    filename = file.name
    segments_h1 = segment_text(text=text, pattern=r"^# (.+)")
    for h1, h1_text in segments_h1:
        logging.debug('\t- H1: "%s" (%d)', h1, len(h1_text))
        segments_h2 = segment_text(text=h1_text, pattern=r"^## (.+)")
        for h2, content in segments_h2:
            content_chars = len(content)
            content_tokens = len(encoding_function(content))
            logging.debug('\t\t- H2: "%s" (%d)', h2, content_chars)
            # explicit raise instead of `assert`: the limit must still be
            # enforced when Python runs with -O (which strips assert statements)
            if content_tokens >= max_output_tokens:
                raise AssertionError(
                    f"Content is too long ({content_tokens}): #{h1} ##{h2}"
                )
            id_ = f"{filename} # {h1} ## {h2}"  # unique doc id
            document = f"# {h1}\n\n## {h2}\n\n{content.strip()}"
            metadata = {"filename": filename, "h1": h1, "h2": h2}
            collection.add(ids=id_, documents=document, metadatas=metadata)
            n_tokens += content_tokens
            n_chars += content_chars
    return n_chars, n_tokens
def main(args: list[str] | None = None) -> int:
    """Main function of the script.

    Parse the options, reset the database client, create the collection
    and import every input file into it.

    Args:
        args: command-line arguments; when None, argparse reads sys.argv.

    Returns:
        The exit code of the script (0 once all files are imported).
    """
    # parsing
    opts = PARSER.parse_args(args)
    # database
    database_path = opts.database
    logging.info("Database path: %s", database_path)
    client = lib.get_database_client(path=database_path)
    # NOTE(review): reset() presumably drops the existing data - confirm in lib
    logging.info("- Reseting database client: %s", client.reset())
    # encoding
    encoding_function = lib.get_encoding_function()
    logging.info("Encoding function: %s", encoding_function)
    # embedding
    embedding_function = lib.get_embedding_function()
    logging.info("Embedding function: %s", embedding_function)
    # collection
    database_collection = opts.collection
    logging.info("Database collection: %s", database_collection)
    collection = client.create_collection(
        name=database_collection, embedding_function=embedding_function
    )
    # files
    for i, file in enumerate(opts.files):
        logging.info("Importing file %d: %s", i, file.name)
        try:
            n_chars, n_tokens = import_file(
                file=file, collection=collection, encoding_function=encoding_function
            )
        finally:
            # argparse.FileType opens the files but never closes them;
            # stdin (given as "-") is left open as it is not owned by the script
            if file is not sys.stdin:
                file.close()
        logging.info(
            "- Docs imported from file %s: %d chars | %d tokens", i, n_chars, n_tokens
        )
    # return
    return 0
# %% ENTRYPOINTS
if __name__ == "__main__":
    # run the script and use the return value of main as the exit status
    raise SystemExit(main())