Spaces:
Paused
Paused
import logging | |
import os | |
import shutil | |
import subprocess | |
import torch | |
from flask import Flask, jsonify, request | |
from langchain.chains import RetrievalQA | |
from langchain.embeddings import HuggingFaceInstructEmbeddings | |
# from langchain.embeddings import HuggingFaceEmbeddings | |
from run_localGPT import load_model | |
from prompt_template_utils import get_prompt_template | |
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler | |
from langchain.vectorstores import Chroma | |
from werkzeug.utils import secure_filename | |
from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME | |
# Configure logging before the first log call below. The module-level
# ``logging.info`` helper implicitly calls ``basicConfig()`` with defaults
# (WARNING level) when the root logger has no handlers; that would drop the
# start-up messages and make any later ``basicConfig`` call a no-op.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s", level=logging.INFO
)

# Select the torch backend: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    DEVICE_TYPE = "mps"
elif torch.cuda.is_available():
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "cpu"

# Passed to the QA chain as ``return_source_documents``.
SHOW_SOURCES = True

logging.info("Running on: %s", DEVICE_TYPE)
logging.info("Display Source Documents set to: %s", SHOW_SOURCES)
# Embedding model used by the vector store, placed on the selected device.
EMBEDDINGS = HuggingFaceInstructEmbeddings(
    model_kwargs={"device": DEVICE_TYPE},
    model_name=EMBEDDING_MODEL_NAME,
)

# uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py
# EMBEDDINGS = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)

# Disabled: wipe the persisted index and re-run ingest.py at start-up.
# if os.path.exists(PERSIST_DIRECTORY):
#     try:
#         shutil.rmtree(PERSIST_DIRECTORY)
#     except OSError as e:
#         print(f"Error: {e.filename} - {e.strerror}.")
# else:
#     print("The directory does not exist")

# run_langest_commands = ["python", "ingest.py"]
# if DEVICE_TYPE == "cpu":
#     run_langest_commands.append("--device_type")
#     run_langest_commands.append(DEVICE_TYPE)

# result = subprocess.run(run_langest_commands, capture_output=True)
# if result.returncode != 0:
#     raise FileNotFoundError(
#         "No files were found inside SOURCE_DOCUMENTS, please put a starter file inside before starting the API!"
#     )

# Open the Chroma store persisted under PERSIST_DIRECTORY and expose it as a
# retriever for the QA chain.
DB = Chroma(
    client_settings=CHROMA_SETTINGS,
    embedding_function=EMBEDDINGS,
    persist_directory=PERSIST_DIRECTORY,
)
RETRIEVER = DB.as_retriever()
# Load the language model once at import time; every request reuses it.
LLM = load_model(
    model_basename=MODEL_BASENAME,
    model_id=MODEL_ID,
    device_type=DEVICE_TYPE,
)

# Prompt for the "llama" template with chat history disabled. ``memory`` is
# returned alongside it but is not passed to the chain below.
prompt, memory = get_prompt_template(history=False, promptTemplate_type="llama")

# "stuff" chain: retrieved documents are handed to the LLM through ``prompt``.
QA = RetrievalQA.from_chain_type(
    chain_type="stuff",
    llm=LLM,
    retriever=RETRIEVER,
    chain_type_kwargs={"prompt": prompt},
    return_source_documents=SHOW_SOURCES,
)

# Flask application object served from the ``__main__`` guard.
app = Flask(__name__)
def delete_source_route():
    """Reset the ``SOURCE_DOCUMENTS`` folder to an empty directory.

    Removes the folder and everything in it when it exists, then creates it
    again.

    Returns:
        A JSON response with a ``message`` confirming the reset.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view; confirm how it is registered with ``app``.
    folder_name = "SOURCE_DOCUMENTS"

    folder_present = os.path.exists(folder_name)
    if folder_present:
        shutil.rmtree(folder_name)
    os.makedirs(folder_name)

    message = f"Folder '{folder_name}' successfully deleted and recreated."
    return jsonify({"message": message})
def save_document_route():
    """Store the uploaded ``document`` file inside ``SOURCE_DOCUMENTS``.

    Reads the ``document`` part from the current request's files, sanitizes
    its filename with ``secure_filename`` and saves it under the
    ``SOURCE_DOCUMENTS`` folder, creating the folder when needed.

    Returns:
        A ``(message, status)`` tuple: ``200`` when the file was saved,
        ``400`` when the ``document`` part is missing, no file was selected,
        or the filename is empty after sanitizing.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view; confirm how it is registered with ``app``.
    if "document" not in request.files:
        return "No document part", 400

    uploaded = request.files["document"]
    # Covers both an empty and a missing filename, so every path below
    # returns an explicit response.
    if not uploaded.filename:
        return "No selected file", 400

    filename = secure_filename(uploaded.filename)
    # secure_filename may strip every character (e.g. "../" or a name with no
    # ASCII characters); saving would then target the directory itself.
    if not filename:
        return "Invalid file name", 400

    folder_path = "SOURCE_DOCUMENTS"
    # exist_ok avoids the check-then-create race between concurrent uploads.
    os.makedirs(folder_path, exist_ok=True)

    uploaded.save(os.path.join(folder_path, filename))
    return "File saved successfully", 200
def run_ingest_route():
    """Rebuild the vector index with ``ingest.py`` and reload the QA chain.

    Deletes ``PERSIST_DIRECTORY`` when it exists, runs ``ingest.py`` in a
    subprocess, then builds a fresh Chroma store, retriever and RetrievalQA
    chain and publishes them through the module globals ``DB``,
    ``RETRIEVER`` and ``QA``.

    Returns:
        A ``(message, status)`` tuple: ``200`` with the script's stdout on
        success, ``500`` with its stderr when the script exits non-zero, or
        ``500`` with the exception text when anything else fails.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view; confirm how it is registered with ``app``.
    global DB
    global RETRIEVER
    global QA
    try:
        if os.path.exists(PERSIST_DIRECTORY):
            try:
                shutil.rmtree(PERSIST_DIRECTORY)
            except OSError as e:
                # Best effort: report the failure and still attempt ingestion.
                print(f"Error: {e.filename} - {e.strerror}.")
        else:
            print("The directory does not exist")

        run_langest_commands = ["python", "ingest.py"]
        # The device flag is only forwarded for CPU.
        # NOTE(review): for other devices ingest.py falls back to its own
        # default; confirm that default matches DEVICE_TYPE (e.g. "mps").
        if DEVICE_TYPE == "cpu":
            run_langest_commands.append("--device_type")
            run_langest_commands.append(DEVICE_TYPE)

        result = subprocess.run(run_langest_commands, capture_output=True)
        if result.returncode != 0:
            # errors="replace" keeps the real failure visible even if the
            # script emitted bytes that are not valid UTF-8.
            stderr_text = result.stderr.decode("utf-8", errors="replace")
            return f"Script execution failed: {stderr_text}", 500

        # Build everything in locals first so the globals are only swapped
        # once the whole chain has been constructed successfully.
        new_db = Chroma(
            persist_directory=PERSIST_DIRECTORY,
            embedding_function=EMBEDDINGS,
            client_settings=CHROMA_SETTINGS,
        )
        new_retriever = new_db.as_retriever()
        prompt, _ = get_prompt_template(promptTemplate_type="llama", history=False)
        new_qa = RetrievalQA.from_chain_type(
            llm=LLM,
            chain_type="stuff",
            retriever=new_retriever,
            return_source_documents=SHOW_SOURCES,
            chain_type_kwargs={
                "prompt": prompt,
            },
        )
        stdout_text = result.stdout.decode("utf-8", errors="replace")

        DB, RETRIEVER, QA = new_db, new_retriever, new_qa
        return f"Script executed successfully: {stdout_text}", 200
    except Exception as e:
        # Request boundary: log the traceback and report the error to the caller.
        logging.exception("Ingestion failed")
        return f"Error occurred: {str(e)}", 500
def prompt_route():
    """Answer the ``user_prompt`` form field with the current QA chain.

    Returns:
        ``(json, 200)`` where the JSON object has the keys ``Prompt``,
        ``Answer`` and ``Sources``; each source is a pair of the source
        file's base name and the document's page content. Returns
        ``("No user prompt received", 400)`` when the form field is missing
        or empty.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view; confirm how it is registered with ``app``.
    user_prompt = request.form.get("user_prompt")
    if not user_prompt:
        return "No user prompt received", 400

    # Get the answer from the chain.
    res = QA(user_prompt)
    answer = res["result"]
    # The chain only returns source documents when it was built with
    # return_source_documents=True; fall back to no sources otherwise.
    docs = res.get("source_documents", [])

    sources = [
        (os.path.basename(str(document.metadata["source"])), str(document.page_content))
        for document in docs
    ]
    prompt_response_dict = {
        "Prompt": user_prompt,
        "Answer": answer,
        "Sources": sources,
    }
    return jsonify(prompt_response_dict), 200
if __name__ == "__main__":
    # Script entry point: set up root logging, then serve the API on port 5110
    # with Flask's debug mode off.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s",
    )
    app.run(port=5110, debug=False)