# ---------------------------------------- IMPORTS ---------------------------------------- #

# import flask and flask_cors to host the api
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS

# import the api class
from assets.source import api, non_streamed_format

# import addons
# here we only use 'create_cloudflare_tunnel', 'translate' and 'message_translation' from the addons
from assets.source.addons import *

# logging module for debugging
import logging

# json module to parse json
from json import loads

# ---------------------------------------- CONFIGURE LOCAL SERVER ---------------------------------------- #

# create flask app
app = Flask(__name__)
app.template_folder = "assets/templates"

# enable cors
CORS(app)

# ---------------------------------------- READ FROM CONFIG FILE ---------------------------------------- #

with open("assets/config.json", "r", encoding="utf-8") as f:
    config_file = loads(f.read())

# copy constants over; optional keys fall back to defaults so a missing key
# does not raise KeyError at startup (use_global) or on every request (use_addons)
DEBUG: bool = config_file.get("DEBUG", False)
PORT: int = config_file.get("PORT", 5000)
HOST: str = config_file.get("HOST", "")
USE_GLOBAL: bool = config_file.get("use_global", False)
USE_ADDONS: bool = config_file.get("use_addons", False)

# check if user wants to use a global server too
if USE_GLOBAL:
    # create a cloudflare tunnel
    create_cloudflare_tunnel(PORT)

# ---------------------------------------- LOGGING CONFIG ---------------------------------------- #

# set logging level
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(message)s")

# ---------------------------------------- HELPERS ---------------------------------------- #


def _generation_params(data: dict) -> dict:
    """Return the optional sampling parameters from a request body, with defaults applied."""
    return {
        "max_tokens": data.get("max_tokens", 150),
        "top_p": data.get("top_p", 0.99),
        "top_k": data.get("top_k", 50),
        "temperature": data.get("temperature", 0.6),
        "frequency_penalty": data.get("frequency_penalty", 1),
        "presence_penalty": data.get("presence_penalty", 1),
    }


def _extract_delta_content(chunk: bytes) -> str:
    """Return the ``choices[0].delta.content`` text of one ``data: {...}`` chunk.

    Returns "" for chunks that carry no text, e.g. ``data: [DONE]`` (not JSON),
    deltas without a ``content`` key, or ``content`` that is not a string.
    """
    try:
        payload = loads(chunk.decode("utf-8").removeprefix("data: "))
        content = payload["choices"][0]["delta"]["content"]
    except (ValueError, KeyError, IndexError, TypeError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError
        logging.debug("Skipping chunk without delta content: %r", chunk)
        return ""
    return content if isinstance(content, str) else ""


def _error_response(message: str, kind: str, error, status: int):
    """Build the JSON error payload shared by all error responses.

    The payload is serialized as a JSON array of three objects (jsonify with
    several positional arguments produces a list), matching the existing format.
    """
    return jsonify(
        {"status": False},
        {"error": [
            {"message": message},
            {"type": kind},
            {"error": f"{error}"},
        ]},
        {"hint": "please report issues on the github page"},
    ), status


# ---------------------------------------- ROUTES ---------------------------------------- #


# chat generation route
@app.route("/chat/completions", methods=["POST"])
def chat():
    """Handle a chat completion request.

    The body must be a JSON object with "messages" and "model"; sampling
    parameters are optional (see ``_generation_params``). When ``"stream"`` is
    truthy the upstream chunks are relayed as text/event-stream; otherwise the
    streamed deltas are concatenated and returned via ``non_streamed_format``.
    Returns 400 if the required fields are missing.
    """
    data = request.get_json()
    if not isinstance(data, dict) or "messages" not in data or "model" not in data:
        return _error_response(
            "The request body must be a JSON object containing 'messages' and 'model'.",
            "bad request",
            "missing required field",
            400,
        )

    # optionally run messages / model name through the addon translators
    messages = message_translation(data["messages"]) if USE_ADDONS else data["messages"]
    model = translate(data["model"]) if USE_ADDONS else data["model"]
    params = _generation_params(data)

    if data.get("stream"):
        logging.info("Streaming requested for model %s\n", model)

        # streaming function. uses text/event-stream instead of application/json
        def stream():
            for chunk in api.chat(messages, model, stream=True, **params):
                yield chunk + b"\n\n"
            # an SSE event is only dispatched once terminated by a blank line;
            # without the trailing "\n\n" clients discard the final [DONE] event
            yield b"data: [DONE]\n\n"

        return app.response_class(stream(), mimetype="text/event-stream")

    # even if not, stream upstream but collect all deltas into a full string
    logging.info("Non-streaming requested for model %s\n", model)
    parts = [
        _extract_delta_content(chunk)
        for chunk in api.chat(messages, model, stream=True, **params)
    ]
    return jsonify(non_streamed_format(model, "".join(parts)))


# route to get all models
@app.route("/models", methods=["GET"])
def get_models():
    """Return the model list reported by ``api.get_models()`` as JSON."""
    return jsonify(api.get_models())


# root route to check if api is online
@app.route("/", methods=["GET"])
def root():
    """Render the landing page (index.html from assets/templates)."""
    return render_template("index.html")


# ---------------------------------------- ERROR HANDLING ---------------------------------------- #


@app.errorhandler(403)
def forbidden(error):
    """Return the JSON error payload with status 403."""
    return _error_response(
        "Something went wrong, the API was blocked from sending a request to the DeepInfra API. Please try again later.",
        "forbidden",
        error,
        403,
    )


@app.errorhandler(500)
def internal_server_error(error):
    """Return the JSON error payload with status 500."""
    return _error_response(
        "Something went wrong, the API was unable to complete your request. Please try again later.",
        "internal server error",
        error,
        500,
    )


# ---------------------------------------- START API ---------------------------------------- #

# start the api
if __name__ == "__main__":
    app.run(debug=DEBUG, port=PORT, host=HOST)

# Path: app.py