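"""Local RAG chat helper.

Embeds the lines of local-rag/temp.txt with Ollama's mxbai-embed-large model,
caches the resulting tensor next to the parsed text file in local-rag/text_parse,
retrieves the most similar lines by cosine similarity, and answers questions
through an Ollama chat model served via its OpenAI-compatible endpoint at
http://localhost:11434/v1.

Running the file directly answers the built-in question "tell me about iterators":

    python <this_script>.py --model llama3

(The --model flag is optional and defaults to llama3.)
"""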
import torch
import ollama
import os
from openai import OpenAI
import argparse

# ANSI escape codes for colored terminal output
PINK = '\033[95m'
CYAN = '\033[96m'
YELLOW = '\033[93m'
NEON_GREEN = '\033[92m'
RESET_COLOR = '\033[0m'


def open_file(filepath):
    with open(filepath, 'r', encoding='utf-8') as infile:
        return infile.read()


def get_relevant_context(rewritten_input, vault_embeddings, vault_content, top_k=3):
    """Return the top_k vault lines most similar to the input query."""
    if vault_embeddings.nelement() == 0:
        return []
    # Embed the query with the same model used for the vault embeddings.
    input_embedding = ollama.embeddings(model='mxbai-embed-large', prompt=rewritten_input)["embedding"]
    # Cosine similarity between the query embedding and every vault line embedding.
    cos_scores = torch.cosine_similarity(torch.tensor(input_embedding).unsqueeze(0), vault_embeddings)
    top_k = min(top_k, len(cos_scores))
    top_indices = torch.topk(cos_scores, k=top_k)[1].tolist()
    relevant_context = [vault_content[idx].strip() for idx in top_indices]
    return relevant_context


def ollama_chat(user_input, system_message, vault_embeddings, vault_content, ollama_model, conversation_history):
    # Pull the most relevant vault lines and show them to the user.
    relevant_context = get_relevant_context(user_input, vault_embeddings, vault_content, top_k=3)
    if relevant_context:
        context_str = "\n".join(relevant_context)
        print("Context Pulled from Documents: \n\n" + CYAN + context_str + RESET_COLOR)
    else:
        print(CYAN + "No relevant context found." + RESET_COLOR)

    # Prepend the retrieved context to the user's question.
    user_input_with_context = user_input
    if relevant_context:
        user_input_with_context = context_str + "\n\n" + user_input

    conversation_history.append({"role": "user", "content": user_input_with_context})

    messages = [
        {"role": "system", "content": system_message},
        *conversation_history
    ]

    # Chat through the OpenAI-compatible client pointed at the local Ollama server.
    response = client.chat.completions.create(
        model=ollama_model,
        messages=messages
    )

    conversation_history.append({"role": "assistant", "content": response.choices[0].message.content})

    return response.choices[0].message.content


def process_text_files(user_input):
    text_parse_directory = os.path.join("local-rag", "text_parse")
    temp_file_path = os.path.join("local-rag", "temp.txt")

    if not os.path.exists(text_parse_directory):
        print(f"Directory '{text_parse_directory}' does not exist.")
        return False

    if not os.path.exists(temp_file_path):
        print("temp.txt does not exist.")
        return False

    # The first line of temp.txt names the parsed text file to work with.
    with open(temp_file_path, 'r', encoding='utf-8') as temp_file:
        first_line = temp_file.readline().strip()

    text_files = [f for f in os.listdir(text_parse_directory) if f.endswith('.txt')]

    if first_line not in text_files:
        print(f"No matching file found for '{first_line}' in text_parse directory.")
        return False

    file_path = os.path.join(text_parse_directory, first_line)
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # A "NOT FINISHED" marker on the second-to-last line means embeddings
    # have not been computed for this file yet.
    if lines[-2].strip() == "====================NOT FINISHED====================":
        print(f"'{first_line}' contains the 'NOT FINISHED' flag. Computing embeddings.")

        vault_content = []
        if os.path.exists(temp_file_path):
            with open(temp_file_path, "r", encoding='utf-8') as vault_file:
                vault_content = vault_file.readlines()

        # Embed every vault line with Ollama's mxbai-embed-large model.
        vault_embeddings = []
        for content in vault_content:
            response = ollama.embeddings(model='mxbai-embed-large', prompt=content)
            vault_embeddings.append(response["embedding"])

        vault_embeddings_tensor = torch.tensor(vault_embeddings)
        print("Embeddings for each line in the vault:")
        print(vault_embeddings_tensor)

        # Cache the embeddings next to the parsed text file.
        with open(os.path.join(text_parse_directory, f"{first_line}_embedding.pt"), "wb") as tensor_file:
            torch.save(vault_embeddings_tensor, tensor_file)

        # Rewrite the file without its final line now that embeddings are saved.
        with open(file_path, 'w', encoding='utf-8') as f:
            f.writelines(lines[:-1])
    else:
        print(f"'{first_line}' does not contain the 'NOT FINISHED' flag or is already complete. Loading tensor if it exists.")

        tensor_file_path = os.path.join(text_parse_directory, f"{first_line}_embedding.pt")
        if os.path.exists(tensor_file_path):
            vault_embeddings_tensor = torch.load(tensor_file_path)
            print("Loaded Vault Embedding Tensor:")
            print(vault_embeddings_tensor)

            vault_content = []
            if os.path.exists(temp_file_path):
                with open(temp_file_path, "r", encoding='utf-8') as vault_file:
                    vault_content = vault_file.readlines()
        else:
            print(f"No tensor file found for '{first_line}'.")
            # Without cached embeddings there is nothing to retrieve against.
            return False

    # Answer the user's question using the retrieved vault context.
    conversation_history = []
    system_message = "You are a helpful assistant that is an expert at extracting the most useful information from a given text"

    response = ollama_chat(user_input, system_message, vault_embeddings_tensor, vault_content, args.model, conversation_history)

    return response


parser = argparse.ArgumentParser(description="Ollama Chat")
parser.add_argument("--model", default="llama3", help="Ollama model to use (default: llama3)")
args = parser.parse_args()

# OpenAI-compatible client pointed at the local Ollama server;
# the api_key value is only a placeholder for the local endpoint.
client = OpenAI(
    base_url='http://localhost:11434/v1',
    api_key='llama3'
)

if __name__ == "__main__":
    print(process_text_files("tell me about iterators"))