# ChadAlpaca-Flask / main.py (Hugging Face Space snapshot, commit 766da77)
from flask import Flask, request, abort, Response
from werkzeug.security import generate_password_hash, check_password_hash
from werkzeug.exceptions import HTTPException
import os, threading, json, waitress, datetime, traceback
from llama_cpp import Llama
from dotenv import load_dotenv
load_dotenv()
import sentry_sdk
from flask import Flask
from sentry_sdk.integrations.flask import FlaskIntegration
# Initialize Sentry error reporting: unhandled exceptions are captured, and
# with traces_sample_rate=1.0 every request is traced for performance
# monitoring (the inline comments below recommend lowering this in prod).
# NOTE(review): the DSN is hardcoded in source; consider moving it to an
# environment variable like the other configuration — TODO confirm.
sentry_sdk.init(
dsn="https://5dcf8a99012c4c86b9b1f0293f6b4c2e@o4505516024004608.ingest.sentry.io/4505541971935232",
integrations=[
FlaskIntegration(),
],
# Set traces_sample_rate to 1.0 to capture 100%
# of transactions for performance monitoring.
# We recommend adjusting this value in production.
traces_sample_rate=1.0
)
#Variables
DEBUGMODEENABLED = (os.getenv('debugModeEnabled', 'False') == 'True')  # only the literal string "True" enables debug mode
modelName = "vicuna"  # basename of the GGML weights file expected under ./resources
llm = None  # llama_cpp.Llama instance; populated by load_alpaca()
AlpacaLoaded = False  # flipped to True once the model has finished loading
#Chat Functions
def load_alpaca():
    """Load the GGML chat model into the module-global `llm` (idempotent).

    Returns:
        "Done" on a fresh successful load, "Error" if the model could not be
        loaded, or "Already Loaded" if a previous call already succeeded.
    """
    global llm, AlpacaLoaded, modelName
    if AlpacaLoaded:
        print("Alpaca already loaded.")
        return "Already Loaded"
    print("Loading Alpaca...")
    try:
        # use_mmap=False keeps the whole model resident in RAM; n_ctx=2048
        # sets the context window.  (use_mlock=True was considered — see the
        # original trailing comment.)
        llm = Llama(
            model_path=f"./resources/{modelName}-ggml-model-q4.bin",
            use_mmap=False,
            n_threads=2,
            verbose=False,
            n_ctx=2048,
        )  # use_mlock=True
    except (AttributeError, ValueError, OSError) as err:
        # BUGFIX: llama_cpp raises ValueError/OSError (FileNotFoundError is an
        # OSError) when the model file is missing or unreadable; the previous
        # `except AttributeError` let those crash the loader thread instead of
        # returning "Error".
        print("Error loading Alpaca. Please make sure you have the model file in the resources folder.")
        print(f"({type(err).__name__}: {err})")
        return "Error"
    AlpacaLoaded = True
    print("Done loading Alpaca.")
    return "Done"
def getChatResponse(modelOutput):
    """Return the assistant's reply text from a llama_cpp chat-completion dict."""
    firstChoice = modelOutput["choices"][0]
    return str(firstChoice["message"]["content"])
def reload_alpaca():
    """Tear down the current model (if any) and load it again.

    Returns:
        "Done" once load_alpaca() has been (re)invoked.
    """
    global llm, AlpacaLoaded, modelName
    if AlpacaLoaded:
        # Drop the only reference so the model's memory can be reclaimed.
        llm = None
        # BUGFIX: the previous code paused on a blocking input() prompt
        # ("Pleease confirm that the memory is cleared!"), which would hang a
        # headless server forever.  Force a collection instead so the weights
        # are actually released before reloading.
        import gc
        gc.collect()
        AlpacaLoaded = False
    load_alpaca()
    return "Done"
#Authentication Functions
def loadHashes():
    """Populate the module-global hashesDict from resources/hashes.json.

    A missing file is treated as an empty credential store.
    """
    global hashesDict
    try:
        with open("resources/hashes.json", "r") as fh:
            data = json.load(fh)
    except FileNotFoundError:
        data = {}
    hashesDict = data
def saveHashes():
    """Persist the module-global hashesDict to resources/hashes.json."""
    global hashesDict
    with open("resources/hashes.json", "w") as fh:
        json.dump(hashesDict, fh)
def addHashes(username: str, password: str):
    """Hash *password* with scrypt, store it under *username*, and persist."""
    global hashesDict
    digest = generate_password_hash(password, method='scrypt')
    hashesDict[username] = digest
    saveHashes()
def checkCredentials(username: str , password: str):
    """Return True iff *username* exists and *password* matches its stored hash."""
    global hashesDict
    storedHash = hashesDict.get(username)
    if storedHash is None:
        return False
    return check_password_hash(storedHash, password)
def verifyHeaders():
    """Authenticate the current request via its Authorization header.

    The header is expected as "user:password".  Aborts with 401 when the
    header is missing or malformed, 403 when the credentials are wrong;
    otherwise returns the authenticated username.
    """
    try:
        user, passw = request.headers['Authorization'].split(":")
    except (KeyError, ValueError):
        # Missing header (KeyError) or wrong number of ':' fields (ValueError).
        abort(401)
    if checkCredentials(user, passw):
        return user
    abort(403)
# Populate hashesDict at import time so request handlers can authenticate.
loadHashes()
# One-off helper for seeding a local test user; keep disabled in production.
#addHashes("test", "test")
#General Functions
def getIsoTime():
    """Return the current local time as an ISO-8601 string (naive, no tz)."""
    # isoformat() already returns str; the old str() wrapper was redundant.
    return datetime.datetime.now().isoformat()
#Flask App
# Flask application instance; routes and error handlers are registered below.
app = Flask(__name__)
@app.route('/')
def main():
    """Landing page: a minimal dark-mode "Hello, World!" HTML document."""
    page_lines = [
        "<!DOCTYPE HTML>",
        "<html>",
        "<head><meta name='color-scheme' content='dark'></head>",
        "<body><p>Hello, World!</p></body>",
        "</html>",
    ]
    return "\n".join(page_lines)
@app.route('/chat', methods=['GET', 'POST'])
def chat():
    """Chat endpoint.

    GET acts as a readiness probe: 200 "Ready" once the model is loaded,
    otherwise 503 "Not Ready".  POST requires an Authorization header,
    takes a JSON message list, and returns the model's reply message as JSON.
    """
    if request.method != 'POST':
        if AlpacaLoaded:
            return "Ready", 200
        return "Not Ready", 503
    print("Chat Completion Requested.")
    verifyHeaders()
    print("Headers verified")
    messages = request.get_json()
    print("Got Message" + str(messages))
    if not AlpacaLoaded:
        print("Alpaca not loaded. ")
        abort(503, "Alpaca not loaded. Please wait a few seconds and try again.")
    modelOutput = llm.create_chat_completion(messages=messages, max_tokens=1024)
    responseMessage = modelOutput["choices"][0]['message']
    print(f"\n\nResponseMessage: {responseMessage}\n\n")
    return Response(json.dumps(responseMessage, indent=2), content_type='application/json')
@app.route('/sentry_check')
def trigger_error():
    """Deliberately raise ZeroDivisionError to verify Sentry error reporting."""
    return 1 / 0  # never returns; the raise is the point
@app.errorhandler(HTTPException)
def handle_exception(e):
    """Render HTTP errors (4xx/5xx aborts) as JSON instead of Flask's HTML."""
    payload = {"error": f"{e.code} - {e.name}", "message": e.description}
    body = json.dumps(payload, indent=2)
    return Response(body, content_type='application/json'), e.code
@app.errorhandler(Exception)
def handle_errors(e):
    """Catch-all handler: log the traceback, report to Sentry, return JSON 500."""
    print(f"INTERNAL SERVER ERROR 500 @ {request.path}")
    exceptionInfo = f"{type(e).__name__}: {str(e)}"
    print(traceback.format_exc())
    sentry_sdk.capture_exception(e)
    payload = {"error": "500 - Internal Server Error", "message": exceptionInfo}
    return Response(json.dumps(payload, indent=2), content_type='application/json'), 500
if __name__ == '__main__':
    # Load the model in the background so the server starts immediately;
    # /chat responds 503 until AlpacaLoaded becomes True.
    threading.Thread(target=load_alpaca, daemon=True).start()
    port = int(os.getenv("port", "8080"))
    print("Server successfully started.")
    if DEBUGMODEENABLED:
        # Flask development server (debug convenience only).
        app.run(host='0.0.0.0', port=port)
    else:
        # Production path: waitress WSGI server.  url_scheme='https' suggests
        # TLS terminates at a proxy in front — TODO confirm deployment setup.
        waitress.serve(app, host='0.0.0.0', port=port, url_scheme='https')