from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import logging
import os

import torch
import uvicorn
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings

# from langchain.embeddings import HuggingFaceEmbeddings
from run_localGPT import load_model
from prompt_template_utils import get_prompt_template

# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma

from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s",
    level=logging.INFO,
)

# Pick the best available device: Apple Silicon (MPS), then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE_TYPE = "mps"
elif torch.cuda.is_available():
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "cpu"

SHOW_SOURCES = True
logging.info(f"Running on: {DEVICE_TYPE}")
logging.info(f"Display Source Documents set to: {SHOW_SOURCES}")

EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": DEVICE_TYPE})

# Load the persisted vector store and expose it as a retriever.
DB = Chroma(
    persist_directory=PERSIST_DIRECTORY,
    embedding_function=EMBEDDINGS,
    client_settings=CHROMA_SETTINGS,
)
RETRIEVER = DB.as_retriever()

# Load the local LLM and build the retrieval QA chain once, at startup.
LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)
prompt, memory = get_prompt_template(promptTemplate_type="llama", history=False)

QA = RetrievalQA.from_chain_type(
    llm=LLM,
    chain_type="stuff",
    retriever=RETRIEVER,
    return_source_documents=SHOW_SOURCES,
    chain_type_kwargs={
        "prompt": prompt,
    },
)


class Predict(BaseModel):
    prompt: str


app = FastAPI()


@app.get("/")
def root():
    return {"API": "A local document Q&A API powered by localGPT."}


@app.post("/predict")
async def predict(data: Predict):
    user_prompt = data.prompt
    if not user_prompt:
        raise HTTPException(status_code=400, detail="Prompt Incorrect")

    # Get the answer (and its source documents) from the chain.
    res = QA(user_prompt)
    answer, docs = res["result"], res["source_documents"]

    prompt_response_dict = {
        "Prompt": user_prompt,
        "Answer": answer,
        # Each source is a (filename, page_content) pair.
        "Sources": [
            (os.path.basename(str(document.metadata["source"])), str(document.page_content))
            for document in docs
        ],
    }
    # FastAPI serializes the dict to JSON automatically.
    return prompt_response_dict


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
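
# --- Usage sketch -----------------------------------------------------------
# A minimal client example for the /predict endpoint above, assuming the
# server is running locally on port 8000 (host, port, and the sample prompt
# are illustrative, not fixed by this file):
#
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "What does the ingested document say about X?"}'
#
# Or with the `requests` library:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/predict",
#       json={"prompt": "What does the ingested document say about X?"},
#   )
#   print(resp.json()["Answer"])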