import logging import os import shutil import subprocess import torch from flask import Flask, jsonify, request, render_template from langchain.chains import RetrievalQA from langchain.embeddings import HuggingFaceInstructEmbeddings # from langchain.embeddings import HuggingFaceEmbeddings from run_localGPT import load_model from prompt_template_utils import get_prompt_template # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain.vectorstores import Chroma from werkzeug.utils import secure_filename from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME if torch.backends.mps.is_available(): DEVICE_TYPE = "mps" elif torch.cuda.is_available(): DEVICE_TYPE = "cuda" else: DEVICE_TYPE = "cpu" SHOW_SOURCES = True logging.info(f"Running on: {DEVICE_TYPE}") logging.info(f"Display Source Documents set to: {SHOW_SOURCES}") EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE}) # uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py # EMBEDDINGS = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME) # if os.path.exists(PERSIST_DIRECTORY): # try: # shutil.rmtree(PERSIST_DIRECTORY) # except OSError as e: # print(f"Error: {e.filename} - {e.strerror}.") # else: # print("The directory does not exist") # run_langest_commands = ["python", "ingest.py"] # if DEVICE_TYPE == "cpu": # run_langest_commands.append("--device_type") # run_langest_commands.append(DEVICE_TYPE) # result = subprocess.run(run_langest_commands, capture_output=True) # if result.returncode != 0: # raise FileNotFoundError( # "No files were found inside SOURCE_DOCUMENTS, please put a starter file inside before starting the API!" # ) # load the vectorstore DB = Chroma( persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS, ) RETRIEVER = DB.as_retriever() LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME) prompt, memory = get_prompt_template(promptTemplate_type="llama", history=False) QA = RetrievalQA.from_chain_type( llm=LLM, chain_type="stuff", retriever=RETRIEVER, return_source_documents=SHOW_SOURCES, chain_type_kwargs={ "prompt": prompt, }, ) app = Flask(__name__) @app.route("/") def index(): return render_template("home.html") @app.route("/api/delete_source", methods=["GET"]) def delete_source_route(): folder_name = "SOURCE_DOCUMENTS" if os.path.exists(folder_name): shutil.rmtree(folder_name) os.makedirs(folder_name) return jsonify({"message": f"Folder '{folder_name}' successfully deleted and recreated."}) @app.route("/api/save_document", methods=["GET", "POST"]) def save_document_route(): if "document" not in request.files: return "No document part", 400 file = request.files["document"] if file.filename == "": return "No selected file", 400 if file: filename = secure_filename(file.filename) folder_path = "SOURCE_DOCUMENTS" if not os.path.exists(folder_path): os.makedirs(folder_path) file_path = os.path.join(folder_path, filename) file.save(file_path) return "File saved successfully", 200 @app.route("/api/run_ingest", methods=["GET"]) def run_ingest_route(): global DB global RETRIEVER global QA try: if os.path.exists(PERSIST_DIRECTORY): try: shutil.rmtree(PERSIST_DIRECTORY) except OSError as e: print(f"Error: {e.filename} - {e.strerror}.") else: print("The directory does not exist") run_langest_commands = ["python", "ingest.py"] if DEVICE_TYPE == "cpu": run_langest_commands.append("--device_type") run_langest_commands.append(DEVICE_TYPE) result = subprocess.run(run_langest_commands, capture_output=True) if result.returncode != 0: return "Script execution failed: {}".format(result.stderr.decode("utf-8")), 500 # load the vectorstore DB = Chroma( persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS, ) RETRIEVER = DB.as_retriever() prompt, memory = get_prompt_template(promptTemplate_type="llama", history=False) QA = RetrievalQA.from_chain_type( llm=LLM, chain_type="stuff", retriever=RETRIEVER, return_source_documents=SHOW_SOURCES, chain_type_kwargs={ "prompt": prompt, }, ) return "Script executed successfully: {}".format(result.stdout.decode("utf-8")), 200 except Exception as e: return f"Error occurred: {str(e)}", 500 @app.route("/api/prompt_route", methods=["GET", "POST"]) def prompt_route(): global QA user_prompt = request.form.get("user_prompt") if user_prompt: # print(f'User Prompt: {user_prompt}') # Get the answer from the chain res = QA(user_prompt) answer, docs = res["result"], res["source_documents"] prompt_response_dict = { "Prompt": user_prompt, "Answer": answer, } prompt_response_dict["Sources"] = [] for document in docs: prompt_response_dict["Sources"].append( (os.path.basename(str(document.metadata["source"])), str(document.page_content)) ) return jsonify(prompt_response_dict), 200 else: return "No user prompt received", 400 if __name__ == "__main__": app.run(host="0.0.0.0", port=5110)