from flask import Flask, request, Response
import logging
from llama_cpp import Llama
import threading
from huggingface_hub import snapshot_download  #, Repository
import huggingface_hub
import gc
import os
import xml.etree.ElementTree as ET
from apscheduler.schedulers.background import BackgroundScheduler
from datetime import datetime, timedelta
from llm_backend import LlmBackend
import json

llm = LlmBackend()
_lock = threading.Lock()

SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT') or "Ты — русскоязычный автоматический ассистент. Ты максимально точно отвечаешь на запросы пользователя, используя русский язык."
CONTEXT_SIZE = int(os.environ.get('CONTEXT_SIZE') or 500)
ENABLE_GPU = (os.environ.get('ENABLE_GPU') or '').lower() in ('1', 'true', 'yes')
GPU_LAYERS = int(os.environ.get('GPU_LAYERS') or 0)
N_GQA = int(os.environ['N_GQA']) if os.environ.get('N_GQA') else None  # must be set to 8 for 70b models
CHAT_FORMAT = os.environ.get('CHAT_FORMAT') or 'llama-2'
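# Example environment overrides (illustrative values only; the variable names match those read above):
#   export CONTEXT_SIZE=2000 ENABLE_GPU=true GPU_LAYERS=35 N_GQA=8 CHAT_FORMAT=llama-2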

# Create a lock object
lock = threading.Lock()

# Flag checked by the token generators; set to True to interrupt generation
stop_generation = False

app = Flask(__name__)

# Configure Flask logging
app.logger.setLevel(logging.DEBUG)

# Variable to store the last request time
last_request_time = datetime.now()

# Initialize the model when the application starts
#model_path = "../models/model-q4_K.gguf"  # Replace with the actual model path
#model_name = "model/ggml-model-q4_K.gguf"

#repo_name = "IlyaGusev/saiga2_13b_gguf"
#model_name = "model-q4_K.gguf"
#repo_name = "IlyaGusev/saiga2_70b_gguf"
#model_name = "ggml-model-q4_1.gguf" | |
repo_name = "IlyaGusev/saiga2_7b_gguf" | |
model_name = "model-q4_K.gguf" | |
local_dir = '.' | |

if os.path.isdir('/data'):
    app.logger.info('Persistent storage enabled')

model = None

MODEL_PATH = snapshot_download(repo_id=repo_name, allow_patterns=model_name) + '/' + model_name
app.logger.info('Model path: ' + MODEL_PATH)

DATASET_REPO_URL = "https://huggingface.co/datasets/muryshev/saiga-chat"
DATA_FILENAME = "data-saiga-cuda-release.xml"
DATA_FILE = os.path.join("dataset", DATA_FILENAME)

HF_TOKEN = os.environ.get("HF_TOKEN")

app.logger.info("hfh: " + huggingface_hub.__version__)
# repo = Repository(
#     local_dir="dataset", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
# )

# def log(req: str = '', resp: str = ''):
#     if req or resp:
#         element = ET.Element("row", {"time": str(datetime.now())})
#         req_element = ET.SubElement(element, "request")
#         req_element.text = req
#         resp_element = ET.SubElement(element, "response")
#         resp_element.text = resp
#
#         with open(DATA_FILE, "ab+") as xml_file:
#             xml_file.write(ET.tostring(element, encoding="utf-8"))
#
#         commit_url = repo.push_to_hub()
#         app.logger.info(commit_url)

def generate_tokens(model, generator):
    global stop_generation
    app.logger.info('generate_tokens started')
    with lock:
        try:
            for token in generator:
                if token == model.token_eos() or stop_generation:
                    stop_generation = False
                    app.logger.info('End generating')
                    yield b''  # End of chunk
                    break

                token_str = model.detokenize([token])  #.decode("utf-8", errors="ignore")
                yield token_str
        except Exception as e:
            app.logger.info('generator exception')
            app.logger.info(e)
            yield b''  # End of chunk

@app.route('/change-context-size', methods=['GET'])  # endpoint path is an assumption
def handler_change_context_size():
    global stop_generation, model
    stop_generation = True

    new_size = int(request.args.get('size', CONTEXT_SIZE))
    init_model(new_size, ENABLE_GPU, GPU_LAYERS)

    return Response('Size changed', content_type='text/plain')

@app.route('/stop-generation', methods=['GET'])  # endpoint path is an assumption
def handler_stop_generation():
    global stop_generation
    stop_generation = True
    return Response('Stopped', content_type='text/plain')

@app.route('/', methods=['GET', 'PUT', 'DELETE', 'PATCH'])  # catch-all for non-POST verbs; an assumption
def generate_unknown_response():
    app.logger.info('unknown method: ' + request.method)

    try:
        request_payload = request.get_json()
        app.logger.info('payload: ' + json.dumps(request_payload))
    except Exception as e:
        app.logger.info('payload empty')

    return Response('What do you want?', content_type='text/plain')

response_tokens = bytearray()


def generate_and_log_tokens(user_request, generator):
    global response_tokens, last_request_time
    for token in llm.generate_tokens(generator):
        if token == b'':  # or (max_new_tokens is not None and i >= max_new_tokens):
            last_request_time = datetime.now()
            # log(json.dumps(user_request), response_tokens.decode("utf-8", errors="ignore"))
            response_tokens = bytearray()
            break

        response_tokens.extend(token)
        yield token

@app.route('/', methods=['POST'])  # endpoint path is an assumption
def generate_response():
    app.logger.info('generate_response')

    with _lock:
        if not llm.is_model_loaded():
            app.logger.info('model loading')
            init_model()

    data = request.get_json()
    app.logger.info(data)

    messages = data.get("messages", [])
    preprompt = data.get("preprompt", "")
    parameters = data.get("parameters", {})

    # Extract parameters from the request
    p = {
        'temperature': parameters.get("temperature", 0.01),
        'truncate': parameters.get("truncate", 1000),
        'max_new_tokens': parameters.get("max_new_tokens", 1024),
        'top_p': parameters.get("top_p", 0.85),
        'repetition_penalty': parameters.get("repetition_penalty", 1.2),
        'top_k': parameters.get("top_k", 30),
        'return_full_text': parameters.get("return_full_text", False)
    }

    generator = llm.create_chat_generator_for_saiga(messages=messages, parameters=p)
    app.logger.info('Generator created')

    # Use Response to stream tokens
    return Response(generate_and_log_tokens(user_request='1', generator=generator),
                    content_type='text/plain', status=200, direct_passthrough=True)
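
# Expected request body for generate_response (shape inferred from the handler above; the exact
# message format consumed by create_chat_generator_for_saiga is defined in llm_backend):
# {
#     "messages": [...],
#     "preprompt": "",
#     "parameters": {
#         "temperature": 0.01, "truncate": 1000, "max_new_tokens": 1024,
#         "top_p": 0.85, "repetition_penalty": 1.2, "top_k": 30,
#         "return_full_text": false
#     }
# }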

def init_model(context_size=None, enable_gpu=ENABLE_GPU, gpu_layer_number=GPU_LAYERS):
    llm.load_model(model_path=MODEL_PATH, context_size=context_size or CONTEXT_SIZE,
                   enable_gpu=enable_gpu, gpu_layer_number=gpu_layer_number, n_gqa=N_GQA)

# Function to check if no requests were made in the last 5 minutes
def check_last_request_time():
    global last_request_time
    current_time = datetime.now()
    if (current_time - last_request_time).total_seconds() > 300:  # 5 minutes in seconds
        # Perform the action (unload the model to free memory)
        llm.unload_model()
        app.logger.info(f"Model unloaded at {current_time}")
    else:
        app.logger.info(f"No action needed at {current_time}")

if __name__ == "__main__":
    scheduler = BackgroundScheduler()
    scheduler.add_job(check_last_request_time, trigger='interval', minutes=1)
    scheduler.start()

    init_model()
    app.run(host="0.0.0.0", port=7860, debug=True, threaded=True)
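
# Example streaming client (illustrative sketch; assumes the server runs on localhost:7860, that the
# POST route assumed above is used, and that messages follow a role/content layout):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/",
#       json={"messages": [{"role": "user", "content": "Привет!"}], "parameters": {"max_new_tokens": 256}},
#       stream=True,
#   )
#   for chunk in resp.iter_content(chunk_size=None):
#       print(chunk.decode("utf-8", errors="ignore"), end="", flush=True)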