from flask import Flask, render_template, request, abort, redirect, url_for, Response
from werkzeug.security import generate_password_hash, check_password_hash
from werkzeug.exceptions import HTTPException

import os, threading, json, waitress, datetime
from llama_cpp import Llama
from dotenv import load_dotenv
load_dotenv()
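
#Expected .env / environment variables (values shown are illustrative defaults, not
#requirements of the code):
#  debugModeEnabled=False   - "True" runs Flask's dev server instead of waitress
#  port=8080                - port the server listens on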

#Variables
DEBUGMODEENABLED = (os.getenv('debugModeEnabled', 'False') == 'True')
llm = None
AlpacaLoaded = False

#Chat Functions
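#The model is loaded once, in a background thread started from __main__, so the HTTP
#server can accept requests while the weights are still being read; /chat answers 503
#until AlpacaLoaded becomes True.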
def load_alpaca():
  global llm, AlpacaLoaded
  if not AlpacaLoaded:
    print("Loading Alpaca...")
    try:
      #mmap is disabled so the model is read fully into RAM; n_threads=2 keeps CPU usage
      #modest. use_mlock=True could additionally pin the weights in memory.
      llm = Llama(model_path="./resources/ggml-model-q4_0.bin", use_mmap=False, n_threads=2, verbose=False)
      AlpacaLoaded = True
      print("Done loading Alpaca.")
    except (AttributeError, ValueError):
      #llama_cpp raises ValueError if the model file is missing or unreadable.
      print("Error loading Alpaca. Please make sure you have the model file in the resources folder.")
  else:
    print("Alpaca already loaded.")

#Helper: extract the assistant's reply text from a create_chat_completion() result.
def getChatResponse(modelOutput):
  return str(modelOutput["choices"][0]['message']['content'])


#Authentication Functions
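#Credentials are stored in resources/hashes.json as a flat JSON object mapping usernames
#to werkzeug scrypt hashes, e.g. {"alice": "scrypt:32768:8:1$<salt>$<hash>"} (example
#value is illustrative); plaintext passwords are never written to disk.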
def loadHashes():
  global hashesDict
  try:
    with open("resources/hashes.json", "r") as f:
      hashesDict = json.load(f)
  except FileNotFoundError:
    hashesDict = {}
  
def saveHashes():
  global hashesDict
  with open("resources/hashes.json", "w") as f:
    json.dump(hashesDict, f)

def addHashes(username: str, password: str):
  global hashesDict
  hashesDict[username] = generate_password_hash(password, method='scrypt')
  saveHashes()

def checkCredentials(username: str, password: str):
  global hashesDict
  if username in hashesDict:
    return check_password_hash(hashesDict[username], password)
  else:
    return False

def verifyHeaders():
  #Check + Obtain Authorization header (expected format: "username:password")
  try:
    #Split on the first colon only, so passwords containing ":" still work.
    user, passw = request.headers['Authorization'].split(":", 1)
  except (KeyError, ValueError):
    abort(401)

  #Check if Authorization header is valid
  credentialsValid = checkCredentials(user, passw)
  if not credentialsValid:
    abort(403)
  else:
    return user


loadHashes()
#addHashes("test", "test") #Uncomment (with real credentials) to register a new user

#General Functions
def getIsoTime():
  return str(datetime.datetime.now().isoformat())

#Flask App
app = Flask(__name__)

@app.route('/')
def main():
  return """<!DOCTYPE HTML>
<html>
  <head><meta name='color-scheme' content='dark'></head>
  <body><p>Hello, World!</p></body>
</html>"""

@app.route('/chat', methods=['GET', 'POST'])
def chat():
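  """Chat endpoint.

  POST: expects an OpenAI-style message list as the JSON body, e.g.
  [{"role": "user", "content": "Hello"}], plus a plaintext "username:password"
  Authorization header (see verifyHeaders); returns the assistant message as JSON.
  GET: reports whether the model has finished loading.

  Example request (hypothetical host and credentials):
    curl -X POST http://localhost:8080/chat \
      -H 'Authorization: alice:secret' \
      -H 'Content-Type: application/json' \
      -d '[{"role": "user", "content": "Hello"}]'
  """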
  if request.method == 'POST':
    verifyHeaders()
    print("Headers verified")
    messages = request.get_json()
    print("Got message: " + str(messages))

    if AlpacaLoaded:
      modelOutput = llm.create_chat_completion(messages=messages, max_tokens=128)
      responseMessage = modelOutput["choices"][0]['message']
      print(f"\n\nResponseMessage: {responseMessage}\n\n")
      return Response(json.dumps(responseMessage, indent=2), content_type='application/json')
    else:
      abort(503, "Alpaca not loaded. Please wait a few seconds and try again.")
  else:
    if AlpacaLoaded:
      return "Ready", 200
    return "Not Ready", 503

@app.errorhandler(HTTPException)
def handle_exception(e):
  #Serialize errors to JSON instead of Flask's default HTML error pages.
  body = {"error": f"{e.code} - {e.name}", "message": e.description}
  return Response(json.dumps(body, indent=2), content_type='application/json'), e.code

if __name__ == '__main__':
  threading.Thread(target=load_alpaca, daemon=True).start()

  port = int(os.getenv("port", "8080"))
  print(f"Starting server on port {port}...")
  if DEBUGMODEENABLED:
    app.run(host='0.0.0.0', port=port)
  else:
    waitress.serve(app, host='0.0.0.0', port=port, url_scheme='https')