# Odi / rag_backend.py
# Author: eaglesarezzo — commit e6fab7f "Update rag_backend.py" (verified), 3.19 kB.
import os

from llama_cpp import Llama
from llama_index.core import (
    Document,
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
# Disable llama_index's global default LLM. In this module llama_index is only
# used for embedding and retrieval (generate_prompt calls query_engine.retrieve,
# never a full query). The local llama_cpp model is loaded separately via
# Backend.load_model. NOTE(review): per llama_index docs, a None LLM is replaced
# by a mock LLM, so building query engines does not require an OpenAI key.
Settings.llm = None
class Backend:
    """Retrieval backend: builds or loads a vector index and turns retrieved chunks into prompts.

    A local llama.cpp model can be loaded with :meth:`load_model`. Indexing and
    retrieval use only the HuggingFace embedding model created in ``__init__``.
    """

    def __init__(self):
        # llama_cpp.Llama instance; None until load_model() is called.
        self.llm = None
        # File name (relative to models/) of the currently loaded model, if any.
        self.llm_model = None
        # The same embedding model is used when building and when loading the index,
        # so stored vectors and query vectors are comparable.
        self.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
        # Directory the index is persisted to and loaded from.
        self.PERSIST_DIR = "./db"
        os.makedirs(self.PERSIST_DIR, exist_ok=True)

    def load_model(self, model_path, *, n_gpu_layers=81, n_batch=1024, n_ctx=8192, flash_attn=True):
        """Load a llama.cpp model from ``models/<model_path>`` into ``self.llm``.

        Args:
            model_path: File name of the model inside the ``models/`` directory.
            n_gpu_layers: Number of layers offloaded to the GPU.
            n_batch: Prompt-processing batch size.
            n_ctx: Context window size in tokens.
            flash_attn: Whether to enable flash attention.

        The keyword defaults equal the values previously hard-coded here.
        """
        self.llm = Llama(
            model_path=f"models/{model_path}",
            flash_attn=flash_attn,
            n_gpu_layers=n_gpu_layers,
            n_batch=n_batch,
            n_ctx=n_ctx,
        )
        self.llm_model = model_path

    @staticmethod
    def _make_query_engine(index):
        """Return the query engine configuration shared by the create and load paths."""
        return index.as_query_engine(similarity_top_k=4, response_mode="tree_summarize")

    @staticmethod
    def _read_documents(root_dir):
        """Read every file under *root_dir* (recursively) as UTF-8 text; one Document per file.

        Files that cannot be opened or decoded as UTF-8 are reported and skipped.
        """
        documents = []
        for root, _dirs, files in os.walk(root_dir):
            for file_name in files:
                file_path = os.path.join(root, file_name)
                try:
                    with open(file_path, "r", encoding="utf-8") as f:
                        content = f.read()
                except (OSError, UnicodeDecodeError) as e:
                    # Best-effort: skip unreadable or non-text files. Only I/O and
                    # decode errors are caught, so programming errors are not hidden.
                    print(f"Error reading file {file_path}: {str(e)}")
                    continue
                documents.append(Document(text=content, metadata={"source": file_path}))
                print(f"Successfully read file: {file_path}")
        return documents

    def create_index_for_query_engine(self, matched_path):
        """Build a vector index from the text files under *matched_path* and return a query engine.

        Documents are split into 256-token chunks with a 64-token overlap,
        embedded with ``self.embed_model``, and the index is persisted to
        ``self.PERSIST_DIR``.
        """
        print(f"Attempting to read files from: {matched_path}")
        documents = self._read_documents(matched_path)
        print(f"Number of documents loaded: {len(documents)}")
        if not documents:
            print(f"Warning: no readable documents found under {matched_path}; the index will be empty")
        nodes = SentenceSplitter(
            chunk_size=256, chunk_overlap=64, paragraph_separator="\n\n"
        ).get_nodes_from_documents(documents)
        index = VectorStoreIndex(nodes, embed_model=self.embed_model)
        index.storage_context.persist(persist_dir=self.PERSIST_DIR)
        return self._make_query_engine(index)

    def load_index_for_query_engine(self):
        """Load the index previously persisted to ``self.PERSIST_DIR`` and return a query engine.

        This uses llama_index's default storage context (not a FAISS store).
        """
        storage_context = StorageContext.from_defaults(persist_dir=self.PERSIST_DIR)
        index = load_index_from_storage(storage_context, embed_model=self.embed_model)
        return self._make_query_engine(index)

    def generate_prompt(self, query_engine, message):
        """Build a prompt containing the chunks retrieved for *message*, followed by the question.

        The prompt text is in Italian. Roughly: "Consider this as your personal
        knowledge base: ==Knowledge== <chunks> ===== Question: <message>".
        """
        relevant_chunks = query_engine.retrieve(message)
        print(f"Found: {len(relevant_chunks)} relevant chunks")
        for idx, chunk in enumerate(relevant_chunks, start=1):
            print(f"{idx}) {chunk.text[:64]}...")
        knowledge = "".join(chunk.text + "\n\n" for chunk in relevant_chunks)
        return (
            "Considera questo come tua base di conoscenza personale:\n==========Conoscenza===========\n"
            + knowledge
            + "\n======================\nDomanda: "
            + message
        )