import threading
import http.server
import socketserver
import os
import yaml
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import gradio as gr
from utils.upload_file import UploadFile
from utils.chatbot import ChatBot
from utils.ui_settings import UISettings
from utils.load_config import LoadConfig
from pyprojroot import here
# Load the app config
with open(here("configs/app_config.yml")) as cfg:
    app_config = yaml.load(cfg, Loader=yaml.FullLoader)

PORT = app_config["serve"]["port"]
DIRECTORY1 = app_config["directories"]["data_directory"]
DIRECTORY2 = app_config["directories"]["data_directory_2"]
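# For reference, a minimal configs/app_config.yml covering only the keys read above
# (values are illustrative, not taken from the project):
#
#   serve:
#     port: 8000
#   directories:
#     data_directory: data/docs
#     data_directory_2: data/docs_2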
# ================================
# Part 1: Reference Serve Code
# ================================
class MultiDirectoryHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
    """Serve files from multiple directories."""

    def translate_path(self, path):
        parts = path.split('/', 2)
        if len(parts) > 1:
            first_directory = parts[1]
            if first_directory == os.path.basename(DIRECTORY1):
                path = os.path.join(DIRECTORY1, *parts[2:])
            elif first_directory == os.path.basename(DIRECTORY2):
                path = os.path.join(DIRECTORY2, *parts[2:])
            else:
                file_path1 = os.path.join(DIRECTORY1, first_directory)
                file_path2 = os.path.join(DIRECTORY2, first_directory)
                if os.path.isfile(file_path1):
                    return file_path1
                elif os.path.isfile(file_path2):
                    return file_path2
        return super().translate_path(path)


def start_reference_server():
    with socketserver.TCPServer(("", PORT), MultiDirectoryHTTPRequestHandler) as httpd:
        print(f"Serving at port {PORT}")
        httpd.serve_forever()
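# With this handler, a reference link such as
#   http://localhost:<PORT>/<basename of DIRECTORY1>/some_file.pdf
# is resolved inside DIRECTORY1 (and likewise for DIRECTORY2); any other path falls
# back to SimpleHTTPRequestHandler's default lookup relative to the working directory.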
# ================================
# Part 2: LLM Serve Code
# ================================
APPCFG = LoadConfig()
app = Flask(__name__)

# Load the LLM and tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    APPCFG.llm_engine, token=APPCFG.gemma_token, device=APPCFG.device)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="BioMistral/BioMistral-7B",
    token=APPCFG.gemma_token,
    torch_dtype=torch.float16,
    device_map=APPCFG.device)
app_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer
)
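# NOTE: the tokenizer is loaded from APPCFG.llm_engine while the model checkpoint is
# hard-coded to "BioMistral/BioMistral-7B". For apply_chat_template (used below) to
# produce prompts in the format the model was trained on, both should normally point
# at the same checkpoint. Tokenizers also have no device placement, so the `device=`
# argument on the tokenizer line above is effectively a no-op.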
# Expose the generator over HTTP. The route path is an assumption here; it must match
# whatever URL the chatbot client posts its requests to.
@app.route("/generate_text", methods=["POST"])
def generate_text():
    data = request.json
    prompt = data.get("prompt", "")
    max_new_tokens = data.get("max_new_tokens", 1000)
    do_sample = data.get("do_sample", True)
    temperature = data.get("temperature", 0.1)
    top_k = data.get("top_k", 50)
    top_p = data.get("top_p", 0.95)
    tokenized_prompt = app_pipeline.tokenizer.apply_chat_template(
        prompt, tokenize=False, add_generation_prompt=True)
    outputs = app_pipeline(
        tokenized_prompt,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    )
    # Return only the newly generated continuation, not the echoed prompt
    return jsonify({"response": outputs[0]["generated_text"][len(tokenized_prompt):]})
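# Example client call (illustrative: the "/generate_text" path and port 8888 are
# assumptions of this script, and `prompt` must be a chat-style message list because
# it is passed through tokenizer.apply_chat_template):
#
#   import requests
#   reply = requests.post(
#       "http://localhost:8888/generate_text",
#       json={"prompt": [{"role": "user", "content": "What are the symptoms of anemia?"}],
#             "max_new_tokens": 256},
#   ).json()["response"]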
def start_llm_server():
    # debug stays False: Flask's debug reloader needs the main thread's signal handling,
    # and this server is launched from a worker thread in __main__ below.
    app.run(debug=False, port=8888)
# ================================
# Part 3: Gradio Chatbot Code
# ================================
def start_gradio_app():
    with gr.Blocks() as demo:
        with gr.Tabs():
            with gr.TabItem("Med-App"):
                # First row
                with gr.Row() as row_one:
                    with gr.Column(visible=False) as reference_bar:
                        ref_output = gr.Markdown()
                    with gr.Column() as chatbot_output:
                        chatbot = gr.Chatbot(
                            [], elem_id="chatbot", bubble_full_width=False, height=500,
                            avatar_images=("images/test.png", "images/Gemma-logo.png")
                        )
                        chatbot.like(UISettings.feedback, None, None)
                # Second row
                with gr.Row():
                    input_txt = gr.Textbox(
                        lines=4, scale=8, placeholder="Enter text and press enter, or upload PDF files"
                    )
                # Third row
                with gr.Row() as row_two:
                    text_submit_btn = gr.Button(value="Submit text")
                    btn_toggle_sidebar = gr.Button(value="References")
                    upload_btn = gr.UploadButton(
                        "📁 Upload PDF or doc files", file_types=['.pdf', '.doc'], file_count="multiple"
                    )
                    clear_button = gr.ClearButton([input_txt, chatbot])
                    rag_with_dropdown = gr.Dropdown(
                        label="RAG with", choices=["Preprocessed doc", "Upload doc: Process for RAG"], value="Preprocessed doc"
                    )
                # Fourth row
                with gr.Row() as row_four:
                    temperature_bar = gr.Slider(
                        minimum=0.1, maximum=1, value=0.1, step=0.1, label="Temperature",
                        info="Increasing the temperature will make the model answer more creatively."
                    )
                    top_k = gr.Slider(
                        minimum=0.0, maximum=100.0, step=1, label="top_k", value=50,
                        info="A lower value (e.g. 10) will result in more conservative answers."
                    )
                    top_p = gr.Slider(
                        minimum=0.0, maximum=1.0, step=0.01, label="top_p", value=0.95,
                        info="A lower value will generate more focused and conservative text."
                    )
                # Process uploaded files and text
                file_msg = upload_btn.upload(
                    fn=UploadFile.process_uploaded_files, inputs=[upload_btn, chatbot, rag_with_dropdown],
                    outputs=[input_txt, chatbot], queue=False
                )
                txt_msg = input_txt.submit(
                    fn=ChatBot.respond, inputs=[chatbot, input_txt, rag_with_dropdown, temperature_bar, top_k, top_p],
                    outputs=[input_txt, chatbot, ref_output], queue=False
                ).then(lambda: gr.Textbox(interactive=True), None, [input_txt], queue=False)
                text_submit_btn.click(
                    fn=ChatBot.respond, inputs=[chatbot, input_txt, rag_with_dropdown, temperature_bar, top_k, top_p],
                    outputs=[input_txt, chatbot, ref_output], queue=False
                ).then(lambda: gr.Textbox(interactive=True), None, [input_txt], queue=False)
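                # NOTE: btn_toggle_sidebar is defined above but never wired up in the
                # original snippet. A minimal sketch, assuming a plain visibility flip
                # (the project may instead provide its own UISettings handler):
                sidebar_state = gr.State(False)
                btn_toggle_sidebar.click(
                    lambda visible: (gr.Column(visible=not visible), not visible),
                    [sidebar_state],
                    [reference_bar, sidebar_state],
                    queue=False,
                )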
    demo.launch()
# ================================
# Main: Running all services concurrently
# ================================
if __name__ == "__main__":
    # Start all services in separate threads
    reference_server_thread = threading.Thread(target=start_reference_server)
    llm_server_thread = threading.Thread(target=start_llm_server)
    gradio_app_thread = threading.Thread(target=start_gradio_app)

    reference_server_thread.start()
    llm_server_thread.start()
    gradio_app_thread.start()

    # Keep the main thread alive
    reference_server_thread.join()
    llm_server_thread.join()
    gradio_app_thread.join()
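    # NOTE: all three servers block indefinitely, so these join() calls never return in
    # normal operation; stop the process externally (e.g. Ctrl+C / SIGTERM) to shut it down.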