Spaces:
Paused
Paused
import os | |
import glob | |
import shutil | |
import subprocess | |
import asyncio | |
from typing import Any, Dict, List | |
from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect | |
from fastapi.staticfiles import StaticFiles | |
from pydantic import BaseModel | |
# import torch | |
from langchain.chains import RetrievalQA | |
from langchain.embeddings import HuggingFaceInstructEmbeddings | |
from langchain.prompts import PromptTemplate | |
from langchain.memory import ConversationBufferMemory | |
from langchain.callbacks.base import BaseCallbackHandler | |
from langchain.schema import LLMResult | |
# from langchain.embeddings import HuggingFaceEmbeddings | |
from load_models import load_model | |
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler | |
from langchain.vectorstores import Chroma | |
from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME, PATH_NAME_SOURCE_DIRECTORY | |
class Predict(BaseModel):
    """Request body for a QA query: the text of the user's question."""

    prompt: str
class Delete(BaseModel):
    """Request body naming a file to remove from the source-documents directory."""

    filename: str
# Device for the embedding model and the LLM. Defaults to "cuda" as before; set
# the DEVICE_TYPE environment variable (e.g. "cpu" or "mps") to run on other
# hardware without editing the code.
DEVICE_TYPE = os.environ.get("DEVICE_TYPE", "cuda")
# Passed as return_source_documents to the QA chain, so answers include the
# documents they were built from.
SHOW_SOURCES = True

EMBEDDINGS = HuggingFaceInstructEmbeddings(
    model_name=EMBEDDING_MODEL_NAME,
    model_kwargs={"device": DEVICE_TYPE},
)

# Open the persisted Chroma vector store and expose it as a retriever.
DB = Chroma(
    persist_directory=PERSIST_DIRECTORY,
    embedding_function=EMBEDDINGS,
    client_settings=CHROMA_SETTINGS,
)
RETRIEVER = DB.as_retriever()
class MyCustomSyncHandler(BaseCallbackHandler):
    """LangChain callback handler that forwards streamed tokens to a sink.

    A caller assigns ``callback`` (a one-argument callable) right before it
    invokes the chain. Every generated token is passed to that callable and
    echoed to stdout. ``end`` is False while an LLM run is in progress and
    becomes True when the run finishes or fails.
    """

    def __init__(self):
        self.end = False
        self.callback = None

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        # Only reset the completion flag. The caller installs ``callback`` before
        # the chain starts, so clearing it here would throw it away and no token
        # would ever reach the consumer.
        self.end = False

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self.end = True
        # Detach the sink once the run is over, so a later, unrelated run does
        # not stream into a consumer that may already be gone. With the "stuff"
        # chain used in this file there is one LLM run per query.
        self.callback = None

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        # A failed run is also finished; drop the sink for the same reason as above.
        self.end = True
        self.callback = None

    def on_llm_new_token(self, token: str, **kwargs) -> Any:
        if self.callback:
            self.callback(token)
        print(token)
# One handler instance shared module-wide. It is handed to load_model as a
# callback (presumably attached to the returned LLM for token streaming — confirm
# in load_models.py). A request routes tokens to itself by assigning
# ``handlerToken.callback``.
handlerToken = MyCustomSyncHandler()
LLM = load_model(
    model_id=MODEL_ID,
    model_basename=MODEL_BASENAME,
    device_type=DEVICE_TYPE,
    callbacks=[handlerToken],
    stream=True,
)
# Prompt for the "stuff" chain: the retrieved documents fill {context} and the
# user's query fills {question}.
template = """You are a helpful, respectful and honest assistant.
Always answer in the most helpful and safe way possible without trying to make up an answer, if you don't know the answer just say "I don't know" and don't share false information or topics that were not provided in your training. Use a maximum of 15 sentences. Your answer should be as concise and clear as possible. Always say "thank you for asking!" at the end of your answer.
Context: {context}
Question: {question}
"""
QA_CHAIN_PROMPT = PromptTemplate(
    template=template,
    input_variables=["context", "question"],
)

# Retrieval-augmented QA chain over RETRIEVER. It is rebuilt by
# run_ingest_route after the documents are re-ingested.
QA = RetrievalQA.from_chain_type(
    llm=LLM,
    retriever=RETRIEVER,
    chain_type="stuff",
    chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
    return_source_documents=SHOW_SOURCES,
)

# Conversation buffer keyed on the "question" input. It is not attached to the
# chain built above.
memory = ConversationBufferMemory(input_key="question", memory_key="history")
# Two ASGI apps: the API lives on ``api_app`` (mounted under /api), and the root
# app serves the "static" directory, with html=True for index pages.
app = FastAPI(title="homepage-app")
api_app = FastAPI(title="api app")

# Starlette tries mounts in registration order, and a mount at "/" matches every
# path, so the /api mount has to be registered before the static one.
# NOTE(review): none of the handler functions in this copy of the file carry a
# route decorator (e.g. @api_app.get(...)); confirm how they get registered.
app.mount("/api", api_app, name="api")
app.mount("/", StaticFiles(directory="static", html=True), name="static")
def run_ingest_route():
    """Re-ingest the source documents and rebuild the QA chain.

    Removes ``PERSIST_DIRECTORY`` if it exists and runs ``ingest.py`` in a
    child process using the interpreter that runs this server. It then reopens
    the Chroma store and rebinds the module-level ``DB``, ``RETRIEVER`` and
    ``QA``.

    Returns:
        ``{"response": "The training was successfully completed"}``.

    Raises:
        HTTPException: 400 if ``ingest.py`` exits non-zero (the detail carries
            its stderr); 500 if the old store cannot be removed or if anything
            else fails.
    """
    global DB
    global RETRIEVER
    global QA
    import sys  # local import: only needed to locate the running interpreter

    try:
        # A missing store (e.g. on first run) is fine: there is nothing to delete.
        if os.path.exists(PERSIST_DIRECTORY):
            try:
                shutil.rmtree(PERSIST_DIRECTORY)
            except OSError as e:
                raise HTTPException(
                    status_code=500, detail=f"Error: {e.filename} - {e.strerror}."
                ) from e

        # Run ingest.py with this server's interpreter, so it sees the same
        # installed packages.
        run_langest_commands = [sys.executable or "python", "ingest.py"]
        # NOTE(review): the flag is only passed for "cpu"; presumably ingest.py's
        # own default covers the GPU case — confirm.
        if DEVICE_TYPE == "cpu":
            run_langest_commands.extend(["--device_type", DEVICE_TYPE])
        result = subprocess.run(run_langest_commands, capture_output=True)
        if result.returncode != 0:
            stderr = result.stderr.decode("utf-8", errors="replace").strip()
            raise HTTPException(
                status_code=400, detail=f"Script execution failed: {stderr}"
            )

        # Reload the freshly written vector store.
        DB = Chroma(
            persist_directory=PERSIST_DIRECTORY,
            embedding_function=EMBEDDINGS,
            client_settings=CHROMA_SETTINGS,
        )
        RETRIEVER = DB.as_retriever()
        # Same configuration as the chain built at startup. The shared ``memory``
        # is deliberately not attached: the prompt has no {history} slot, so the
        # buffer would only accumulate every exchange without being used.
        QA = RetrievalQA.from_chain_type(
            llm=LLM,
            chain_type="stuff",
            retriever=RETRIEVER,
            return_source_documents=SHOW_SOURCES,
            chain_type_kwargs={
                "prompt": QA_CHAIN_PROMPT,
            },
        )
        return {"response": "The training was successfully completed"}
    except HTTPException:
        # Already carries the intended status and detail; don't re-wrap it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}") from e
def get_files():
    """List the entries of the source-documents directory.

    Returns ``{"directory": <absolute directory path>, "files": [<path>, ...]}``.
    ``files`` holds the glob matches for ``*`` in that directory, so hidden
    entries are not included, and the list is empty if the directory is missing.
    """
    source_dir = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
    entries = glob.glob(os.path.join(source_dir, "*"))
    return {"directory": source_dir, "files": entries}
def delete_source_route(data: Delete):
    """Delete one file from the source-documents directory.

    ``data.filename`` must be a bare file name. Anything containing directory
    components (``../x``, ``/etc/passwd``, ...) is rejected, so a client cannot
    delete files outside the directory.

    Returns:
        ``{"message": "<path> has been deleted."}``.

    Raises:
        HTTPException: 400 if the name is not a bare file name, the file does
            not exist, or removing it fails.
    """
    filename = data.filename
    if filename in ("", ".", "..") or os.path.basename(filename) != filename:
        raise HTTPException(status_code=400, detail=f"Invalid file name: {filename}")

    path_source_documents = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
    file_to_delete = os.path.join(path_source_documents, filename)

    if not os.path.exists(file_to_delete):
        message = f"The file {file_to_delete} does not exist."
        print(message)
        # Pass the message itself: the old ``detail=print(...)`` passed None.
        raise HTTPException(status_code=400, detail=message)

    try:
        os.remove(file_to_delete)
    except OSError as e:
        message = f"error: {e}."
        print(message)
        raise HTTPException(status_code=400, detail=message) from e

    message = f"{file_to_delete} has been deleted."
    print(message)
    return {"message": message}
async def predict(data: Predict):
    """Answer ``data.prompt`` with the retrieval-QA chain.

    Returns:
        ``{"response": {"Prompt": ..., "Answer": ..., "Sources": [...]}}``, where
        each source is a ``(file basename, page content)`` tuple.

    Raises:
        HTTPException: 400 when the prompt is empty.
    """
    question = data.prompt
    if not question:
        raise HTTPException(status_code=400, detail="Prompt Incorrect")

    # NOTE(review): QA(...) is synchronous, so generation runs on the event-loop
    # thread and stalls other requests until it finishes.
    result = QA(question)
    answer = result["result"]
    source_docs = result["source_documents"]
    sources = [
        (os.path.basename(str(doc.metadata["source"])), str(doc.page_content))
        for doc in source_docs
    ]
    return {"response": {"Prompt": question, "Answer": answer, "Sources": sources}}
async def create_upload_file(file: UploadFile):
    """Store an uploaded document in the source-documents directory.

    The upload is rejected if it is larger than 10 MB, if its content type is
    not in the allow-list below, or if its name is unusable. Directory parts of
    the client-supplied name are dropped, so the file is always written directly
    inside the upload directory. A file with the same name is overwritten.

    Returns:
        ``{"filename": <name the file was stored under>}``.

    Raises:
        HTTPException: 400 for an oversized file, a disallowed content type, or
            an empty / directory-only file name.
    """
    allowed_content_types = {
        "text/plain",
        "text/markdown",
        "text/x-markdown",
        "text/csv",
        "application/msword",
        "application/pdf",
        "application/vnd.ms-excel",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "text/x-python",
        "application/x-python-code",
    }

    # Measure the upload by seeking the underlying file object to its end,
    # then rewind so the full content is copied below.
    file.file.seek(0, 2)
    file_size = file.file.tell()
    await file.seek(0)
    if file_size > 10 * 1024 * 1024:  # more than 10 MB
        raise HTTPException(status_code=400, detail="File too large")

    if file.content_type not in allowed_content_types:
        raise HTTPException(status_code=400, detail="Invalid file type")

    # The name comes from the client. Keep only its last component, so values
    # like "../../app.py" cannot escape the upload directory.
    filename = os.path.basename(file.filename or "")
    if filename in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid file name")

    upload_dir = os.path.join(os.getcwd(), PATH_NAME_SOURCE_DIRECTORY)
    os.makedirs(upload_dir, exist_ok=True)
    dest = os.path.join(upload_dir, filename)
    with open(dest, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    return {"filename": filename}
async def websocket_endpoint(websocket: WebSocket, client_id: int):
    """Stream QA answers token by token over a WebSocket.

    Each received text message is treated as a prompt; a prompt identical to
    the previous one is ignored. While the chain runs, every token that the
    shared ``handlerToken`` forwards is sent back to this client as a text
    frame. ``client_id`` is accepted but not used.
    """
    await websocket.accept()
    loop = asyncio.get_running_loop()

    def send_token(token: str) -> None:
        # Called on the executor thread that runs the chain. ``send_text`` is a
        # coroutine, so it must be scheduled on the event loop. Waiting for each
        # send keeps tokens in order.
        asyncio.run_coroutine_threadsafe(websocket.send_text(token), loop).result()

    last_prompt = ''
    try:
        while True:
            prompt = await websocket.receive_text()
            if prompt == last_prompt:
                continue
            last_prompt = prompt
            # NOTE(review): handlerToken is shared module-wide, so concurrent
            # connections/requests would compete for this callback.
            handlerToken.callback = send_token
            try:
                # QA(...) is synchronous and returns a dict, not an awaitable.
                # Run it off the loop so the sends scheduled by send_token can be
                # delivered while it generates.
                await loop.run_in_executor(None, QA, prompt)
            finally:
                handlerToken.callback = None
    except WebSocketDisconnect:
        print('disconnect')
    except RuntimeError as error:
        print(error)