|
import configparser
import logging
import sqlite3
from typing import Any, Dict, List

import chromadb
import requests

from App_Function_Libraries.Chunk_Lib import improved_chunking_process
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Module configuration, read once at import time from config.txt ---
config = configparser.ConfigParser()
config.read('config.txt')

# Persistent ChromaDB client; collection data lives under `chroma_db_path`.
chroma_db_path = config.get('Database', 'chroma_db_path', fallback='chroma_db')
chroma_client = chromadb.PersistentClient(path=chroma_db_path)

# Embedding backend settings. `provider` selects the branch taken in
# create_embedding(): 'openai', 'local', or 'huggingface'.
embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
embedding_model = config.get('Embeddings', 'model', fallback='text-embedding-3-small')
embedding_api_key = config.get('Embeddings', 'api_key', fallback='')
embedding_api_url = config.get('Embeddings', 'api_url', fallback='')

# Options forwarded verbatim to improved_chunking_process().
chunk_options = {
    'method': config.get('Chunking', 'method', fallback='words'),
    'max_size': config.getint('Chunking', 'max_size', fallback=400),
    'overlap': config.getint('Chunking', 'overlap', fallback=200),
    'adaptive': config.getboolean('Chunking', 'adaptive', fallback=False),
    'multi_level': config.getboolean('Chunking', 'multi_level', fallback=False),
    'language': config.get('Chunking', 'language', fallback='english')
}
|
|
|
|
|
def auto_update_chroma_embeddings(media_id: int, content: str):
    """
    Automatically update ChromaDB embeddings when a new item is ingested into the SQLite database.

    :param media_id: The ID of the newly ingested media item
    :param content: The content of the newly ingested media item
    """
    collection_name = f"media_{media_id}"

    collection = chroma_client.get_or_create_collection(name=collection_name)

    # Probe for the first chunk id only. The previous code built one id per
    # *character* of `content` (len(content) is not the chunk count) and then
    # tested the truthiness of collection.get()'s return value — a dict that
    # is always non-empty — so new items were never embedded.
    existing_embeddings = collection.get(ids=[f"{media_id}_chunk_0"])

    if existing_embeddings['ids']:
        logging.info(f"Embeddings already exist for media ID {media_id}, skipping...")
    else:
        # No embeddings yet: chunk, embed, and store the content.
        process_and_store_content(content, collection_name, media_id)
        logging.info(f"Updated ChromaDB embeddings for media ID: {media_id}")
|
|
|
|
|
|
|
def process_and_store_content(content: str, collection_name: str, media_id: int):
    """
    Chunk `content`, embed each chunk, store the embeddings in ChromaDB, and
    index the chunk texts in the SQLite full-text-search table.

    :param content: Raw text of the media item
    :param collection_name: Target ChromaDB collection name
    :param media_id: ID of the media item, used to build the chunk ids
    """
    chunks = improved_chunking_process(content, chunk_options)
    texts = [chunk['text'] for chunk in chunks]

    embeddings = [create_embedding(text) for text in texts]

    # Chunk ids follow the "{media_id}_chunk_{i}" convention used elsewhere.
    ids = [f"{media_id}_chunk_{i}" for i in range(len(texts))]

    store_in_chroma(collection_name, texts, embeddings, ids)

    # Imported here (not at module level) to avoid a circular import at load time.
    from App_Function_Libraries.DB_Manager import db
    with db.get_connection() as conn:
        cursor = conn.cursor()
        # executemany performs all inserts in one call instead of a Python-level loop.
        cursor.executemany(
            "INSERT INTO media_fts (content) VALUES (?)",
            [(text,) for text in texts]
        )
        conn.commit()
|
|
|
|
|
|
|
def store_in_chroma(collection_name: str, texts: List[str], embeddings: List[List[float]], ids: List[str]):
    """Persist documents, their embeddings, and their ids into a ChromaDB
    collection, creating the collection if it does not already exist."""
    target = chroma_client.get_or_create_collection(name=collection_name)
    target.add(documents=texts, embeddings=embeddings, ids=ids)
|
|
|
|
|
def vector_search(collection_name: str, query: str, k: int = 10) -> List[str]:
    """Embed `query` and return the k nearest documents from the named collection."""
    collection = chroma_client.get_collection(name=collection_name)
    hits = collection.query(
        query_embeddings=[create_embedding(query)],
        n_results=k
    )
    return hits['documents'][0]
|
|
|
|
|
def create_embedding(text: str) -> List[float]:
    """
    Create an embedding vector for `text` using the configured provider.

    :param text: Text to embed
    :return: Embedding as a list of floats
    :raises ValueError: If `embedding_provider` is not 'openai', 'local',
        or 'huggingface'
    """
    if embedding_provider == 'openai':
        import openai
        openai.api_key = embedding_api_key
        response = openai.Embedding.create(input=text, model=embedding_model)
        return response['data'][0]['embedding']
    elif embedding_provider == 'local':
        # Local embedding server speaking a simple JSON request/response protocol.
        response = requests.post(
            embedding_api_url,
            json={"text": text, "model": embedding_model},
            headers={"Authorization": f"Bearer {embedding_api_key}"}
        )
        return response.json()['embedding']
    elif embedding_provider == 'huggingface':
        from transformers import AutoTokenizer, AutoModel
        import torch

        # Cache the tokenizer/model on the function object, keyed by model
        # name, so repeated calls do not reload the weights from disk each time.
        cache = getattr(create_embedding, '_hf_cache', None)
        if cache is None or cache[0] != embedding_model:
            tokenizer = AutoTokenizer.from_pretrained(embedding_model)
            model = AutoModel.from_pretrained(embedding_model)
            create_embedding._hf_cache = (embedding_model, tokenizer, model)
        _, tokenizer, model = create_embedding._hf_cache

        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)

        # Mean-pool the last hidden state over the sequence dimension.
        embeddings = outputs.last_hidden_state.mean(dim=1)
        return embeddings[0].tolist()
    else:
        raise ValueError(f"Unsupported embedding provider: {embedding_provider}")
|
|
|
|
|
def create_all_embeddings(api_choice: str) -> str:
    """
    Embed every non-trashed media item from the database into a single
    ChromaDB collection, skipping items that already have an embedding.

    :param api_choice: Embedding provider to use; overrides the module-level
        `embedding_provider` (see create_embedding)
    :return: Human-readable status message (errors are caught and reported
        as a string rather than raised)
    """
    try:
        global embedding_provider
        embedding_provider = api_choice

        all_content = get_all_content_from_database()

        if not all_content:
            return "No content found in the database."

        texts_to_embed = []
        embeddings_to_store = []
        ids_to_store = []
        collection_name = "all_content_embeddings"

        collection = chroma_client.get_or_create_collection(name=collection_name)

        for content_item in all_content:
            media_id = content_item['id']
            text = content_item['content']
            doc_id = f"doc_{media_id}"

            # collection.get() always returns a dict (truthy even when no ids
            # match), so inspect the 'ids' list. The previous truthiness check
            # was always true, which skipped every item and embedded nothing.
            existing = collection.get(ids=[doc_id])
            if existing['ids']:
                logging.info(f"Embedding already exists for media ID {media_id}, skipping...")
                continue

            embedding = create_embedding(text)

            texts_to_embed.append(text)
            embeddings_to_store.append(embedding)
            ids_to_store.append(doc_id)

        # The three lists grow in lockstep, so one check suffices.
        if texts_to_embed:
            store_in_chroma(collection_name, texts_to_embed, embeddings_to_store, ids_to_store)

        return "Embeddings created and stored successfully for all new content."
    except Exception as e:
        logging.error(f"Error during embedding creation: {str(e)}")
        return f"Error: {str(e)}"
|
|
|
|
|
def get_all_content_from_database() -> List[Dict[str, Any]]:
    """
    Retrieve all media content from the database that requires embedding.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries, each containing the media ID, content, title, and other relevant fields.

    Raises:
        DatabaseError: If the underlying SQLite query fails.
    """
    try:
        # Imported here rather than at module level to avoid circular imports.
        from App_Function_Libraries.DB_Manager import db
        with db.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT id, content, title, author, type
                FROM Media
                WHERE is_trash = 0 -- Exclude items marked as trash
            """)
            rows = cursor.fetchall()

        # Column order matches the SELECT list above.
        field_names = ('id', 'content', 'title', 'author', 'type')
        return [dict(zip(field_names, row)) for row in rows]

    except sqlite3.Error as e:
        logging.error(f"Error retrieving all content from database: {e}")
        from App_Function_Libraries.SQLite_DB import DatabaseError
        raise DatabaseError(f"Error retrieving all content from database: {e}")
|
|
|
|
|
|
|
|