resume / database.py
Last commit not found
raw
history blame
2.87 kB
#!/usr/bin/env python3
"""Manage the project database."""
# %% IMPORTS
import argparse
import logging
import re
import sys
import typing as T
import lib
# %% LOGGING
# Configure the root logger once at import time: DEBUG and above are emitted,
# each record prefixed with its timestamp and level name.
logging.basicConfig(
    format="[%(asctime)s][%(levelname)s] %(message)s",
    level=logging.DEBUG,
)
# %% PARSING
# Command-line interface: one or more input files plus optional database settings.
PARSER = argparse.ArgumentParser(description=__doc__)
PARSER.add_argument(
    "files",
    # Explicit encoding so reading does not depend on the platform's locale default.
    type=argparse.FileType("r", encoding="utf-8"),
    nargs="+",
    help="markdown files to import into the collection",
)
PARSER.add_argument(
    "--database",
    type=str,
    default=lib.DATABASE_PATH,
    help="path of the database (default: %(default)s)",
)
PARSER.add_argument(
    "--collection",
    type=str,
    default=lib.DATABASE_COLLECTION,
    help="name of the collection to create (default: %(default)s)",
)
# %% FUNCTIONS
def segment_text(text: str, pattern: str) -> T.Iterator[tuple[str, str]]:
    """Split the text into (title, content) pairs delimited by a pattern.

    The pattern is matched in multiline mode and is expected to contain
    exactly one capturing group (the title), so that the split result
    alternates between titles and the content that follows each of them.
    Anything located before the first match is discarded.
    """
    _preamble, *parts = re.split(pattern, text, flags=re.MULTILINE)
    titles = parts[0::2]  # captured groups
    contents = parts[1::2]  # text following each captured group
    return zip(titles, contents)
def import_file(
    file: T.TextIO, collection: lib.Collection, *, max_length: int = 8000
) -> int:
    """Import a markdown file to a database collection.

    The text is split on level-1 headings ("# ..."), and each of those
    sections is split again on level-2 headings ("## ..."). Every H2 section
    becomes one document. Text located before the first H1, or between an H1
    and its first H2, is not imported.

    Args:
        file: open text file to read; its ``name`` is stored in the metadata.
        collection: collection receiving the documents through ``add``.
        max_length: exclusive upper bound on the number of characters of a
            section's content (8000 by default).

    Returns:
        Total number of characters of the documents added to the collection.

    Raises:
        ValueError: if the content of a section has ``max_length`` characters
            or more.
    """
    imported = 0
    text = file.read()
    filename = file.name
    for h1, h1_text in segment_text(text=text, pattern=r"^# (.+)"):
        logging.debug('\t- H1: "%s" (%d)', h1, len(h1_text))
        for h2, content in segment_text(text=h1_text, pattern=r"^## (.+)"):
            logging.debug('\t\t- H2: "%s" (%d)', h2, len(content))
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if len(content) >= max_length:
                raise ValueError(f"Content is too long: #{h1} ##{h2}")
            # Doc id built from the file name and both headings; it is only
            # unique if an (H1, H2) pair is not repeated within the file.
            id_ = f"{filename} # {h1} ## {h2}"
            document = f"# {h1}\n\n## {h2}\n\n{content.strip()}"
            metadata = {"filename": filename, "h1": h1, "h2": h2}
            collection.add(ids=id_, documents=document, metadatas=metadata)
            imported += len(document)
    return imported
def main(args: list[str] | None = None) -> int:
    """Main function of the script.

    Resets the database client, creates the collection and imports every
    file given on the command line into it.

    Args:
        args: command-line arguments; ``None`` lets argparse read them from
            ``sys.argv``.

    Returns:
        The exit code of the script (0 once every file has been imported).
    """
    # parsing (argparse.FileType opens every input file at this point)
    opts = PARSER.parse_args(args)
    try:
        # database
        database_path = opts.database
        logging.info("Database path: %s", database_path)
        client = lib.get_database_client(path=database_path)
        # NOTE(review): the client is reset before the collection is created,
        # presumably discarding previous content — confirm against lib.
        logging.info("- Reseting database client: %s", client.reset())
        # embedding
        embedding_function = lib.get_embedding_function()
        logging.info("Embedding function: %s", embedding_function)
        # collection
        database_collection = opts.collection
        logging.info("Database collection: %s", database_collection)
        collection = client.create_collection(
            name=database_collection, embedding_function=embedding_function
        )
        # files
        for i, file in enumerate(opts.files):
            logging.info("Importing file %d: %s", i, file.name)
            imported = import_file(file=file, collection=collection)
            logging.info("- Docs imported from file %s: %d chars", i, imported)
    finally:
        # Release the handles opened during parsing, even when an import fails.
        for file in opts.files:
            file.close()
    # return
    return 0
# %% ENTRYPOINTS
# When executed as a script, run main() and use its return value as exit status.
if __name__ == "__main__":
    raise SystemExit(main())