# NOTE(review): removed page-scrape residue ("Spaces: / Paused / Paused") that
# preceded the module — it was not valid Python source.
from flask import Flask, render_template, request, abort, redirect, url_for, Response
from werkzeug.security import generate_password_hash, check_password_hash
from werkzeug.exceptions import HTTPException
import os
import threading
import json
import datetime

import waitress
from llama_cpp import Llama
from dotenv import load_dotenv

# Pull configuration from a local .env file, if present.
load_dotenv()

# --- Configuration and model state ---
DEBUGMODEENABLED = (os.getenv('debugModeEnabled', 'False') == 'True')
llm = None            # Llama instance; populated by load_alpaca()
AlpacaLoaded = False  # flips to True once the model has finished loading
#Chat Functions
def load_alpaca():
    """Load the Alpaca GGML model into the module-global ``llm``.

    Idempotent: does nothing (beyond a log line) if the model is already
    loaded. Intended to run in a background thread at startup.
    """
    global llm, AlpacaLoaded
    if AlpacaLoaded:
        print("Alpaca already loaded.")
        return
    print("Loading Alpaca...")
    try:
        llm = Llama(model_path="./resources/ggml-model-q4_0.bin", use_mmap=False, n_threads=2, verbose=False)  # use_mlock=True
    except (AttributeError, ValueError, OSError):
        # Fix: llama_cpp raises ValueError/OSError when the model file is
        # missing or unreadable; the original caught only AttributeError, so
        # the friendly "model file missing" message could never fire for its
        # intended case and the loader thread died with an unhandled error.
        print("Error loading Alpaca. Please make sure you have the model file in the resources folder.")
    else:
        AlpacaLoaded = True
        print("Done loading Alpaca.")
def getChatResponse(modelOutput):
    """Extract the assistant's reply text from a chat-completion result dict."""
    first_choice = modelOutput["choices"][0]
    return str(first_choice['message']['content'])
#Authentication Functions
def loadHashes():
    """Read the username -> password-hash map from disk into ``hashesDict``.

    Falls back to an empty dict when the store file does not exist yet.
    """
    global hashesDict
    try:
        store = open("resources/hashes.json", "r")
    except FileNotFoundError:
        hashesDict = {}
        return
    with store:
        hashesDict = json.load(store)
def saveHashes():
    """Persist ``hashesDict`` to resources/hashes.json, overwriting the file."""
    global hashesDict
    serialized = json.dumps(hashesDict)
    with open("resources/hashes.json", "w") as store:
        store.write(serialized)
def addHashes(username: str, password: str):
    """Store (or replace) ``username``'s scrypt password hash and write to disk."""
    global hashesDict
    hashed = generate_password_hash(password, method='scrypt')
    hashesDict[username] = hashed
    saveHashes()
def checkCredentials(username: str, password: str):
    """Return True iff ``username`` exists and ``password`` matches its hash."""
    global hashesDict
    stored_hash = hashesDict.get(username)
    if stored_hash is None:
        return False
    return check_password_hash(stored_hash, password)
def verifyHeaders():
    """Authenticate the request via its ``Authorization: user:pass`` header.

    Aborts with 401 when the header is missing/malformed, 403 when the
    credentials are wrong; returns the authenticated username otherwise.
    """
    try:
        user, passw = request.headers['Authorization'].split(":")
    except (KeyError, ValueError):
        # Header absent (KeyError) or not exactly one ':' separator (ValueError).
        abort(401)
    if checkCredentials(user, passw):
        return user
    abort(403)
# Load stored credential hashes at import time so authentication works
# as soon as the app starts serving.
loadHashes()
#addHashes("test", "test")  # one-off helper for seeding a local test account
#General Functions
def getIsoTime():
    """Return the current local time as an ISO-8601 formatted string."""
    # isoformat() already returns str, so no extra conversion is needed.
    return datetime.datetime.now().isoformat()
#Flask App
# WSGI application object; served by waitress (prod) or app.run (debug).
app = Flask(__name__)
def main():
    """Serve a minimal dark-mode "Hello, World!" placeholder page."""
    # NOTE(review): no @app.route decorator is visible in this chunk — the
    # route registration presumably lives elsewhere or was lost; confirm.
    page = (
        "<!DOCTYPE HTML>\n"
        "<html>\n"
        "<head><meta name='color-scheme' content='dark'></head>\n"
        "<body><p>Hello, World!</p></body>\n"
        "</html>"
    )
    return page
def chat():
    """Chat endpoint: POST runs a completion, GET reports model readiness.

    POST requires a valid ``Authorization`` header (see verifyHeaders) and a
    JSON body holding the message list; responds 503 until the model loads.
    NOTE(review): no @app.route decorator is visible in this chunk — confirm
    how this handler is registered.
    """
    if request.method != 'POST':
        # GET acts as a readiness probe for clients polling before chatting.
        if AlpacaLoaded:
            return "Ready", 200
        return "Not Ready", 503
    verifyHeaders()
    print("Headers verified")
    messages = request.get_json()
    print("Got Message" + str(messages))
    if not AlpacaLoaded:
        abort(503, "Alpaca not loaded. Please wait a few seconds and try again.")
    modelOutput = llm.create_chat_completion(messages=messages, max_tokens=128)
    responseMessage = modelOutput["choices"][0]['message']
    print(f"\n\nResponseMessage: {responseMessage}\n\n")
    return Response(json.dumps(responseMessage, indent=2), content_type='application/json')
def handle_exception(e):
    """Render an HTTPException as a JSON error payload with its status code.

    Fix: the original passed a plain dict straight to ``Response``; werkzeug
    treats a dict as an iterable of strings, so the response body came out as
    the bare key names ("error"/"message"), not JSON. Serialize explicitly,
    matching the json.dumps style used by the chat handler.
    NOTE(review): expects to be registered via @app.errorhandler(HTTPException)
    (the import exists above) — confirm the registration elsewhere.
    """
    payload = {"error": f"{e.code} - {e.name}", "message": e.description}
    return Response(json.dumps(payload, indent=2), content_type='application/json'), e.code
if __name__ == '__main__':
    # Load the model in the background so the HTTP server can start serving
    # immediately; daemon=True lets the process exit during a long load.
    # Fix: Thread.start() returns None, so the original `t = Thread(...).start()`
    # assigned None and silently discarded the thread handle.
    alpaca_loader = threading.Thread(target=load_alpaca, daemon=True)
    alpaca_loader.start()
    port = int(os.getenv("port", "8080"))
    print("Server successfully started.")
    if DEBUGMODEENABLED:
        # Flask's built-in dev server — debug deployments only.
        app.run(host='0.0.0.0', port=port)
    else:
        # Production: waitress, advertising https to generated URLs
        # (assumes TLS termination happens upstream — confirm).
        waitress.serve(app, host='0.0.0.0', port=port, url_scheme='https')