"""Gradio chat app ("Odi") answering questions with RAG context from local documents.

The first turn asks the user to pick a section (blockchain / metaverse /
payment). On the second turn the chosen section's documents are indexed and
used to answer the original question. Later turns reuse the stored index.
"""
# NOTE(review): `spaces` (Hugging Face ZeroGPU) is kept as the first import,
# as in the original; ZeroGPU expects it to be imported before libraries that
# initialise CUDA — confirm before reordering.
import spaces
import os
import gradio as gr
from models import download_models
from rag_backend import Backend
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
import cv2

# get the models
huggingface_token = os.environ.get('HF_TOKEN')
download_models(huggingface_token)

# Section keyword (matched in the user's message) -> document directory to index.
documents_paths = {
    'blockchain': 'data/blockchain',
    'metaverse': 'data/metaverse',
    'payment': 'data/payment'
}

# initialize backend (not ideal as global variable...)
backend = Backend()

cv2.setNumThreads(1)


@spaces.GPU(duration=20)
def respond(
    message,
    history: list[tuple[str, str]],
    model,
    system_message,
    max_tokens,
    temperature,
    top_p,
    top_k,
    repeat_penalty,
):
    """Stream a RAG-augmented chat reply for the Gradio ChatInterface.

    How retrieval context is chosen depends on the conversation turn:

    * Second turn (``history`` has exactly one entry): if the new message
      contains one of the ``documents_paths`` keywords, that section's
      documents are indexed. The *first* user message is then rewritten into
      a context-augmented prompt.
    * Later turns (``history`` has more than one entry): the stored index is
      reloaded, and the current message is rewritten with retrieved context.
    * Otherwise the message is sent to the model unchanged.

    Args:
        message: The latest user message.
        history: Prior (user, assistant) message pairs.
        model: GGUF model filename. It is loaded through
            ``backend.load_model`` when nothing is loaded yet or it differs
            from ``backend.llm_model``.
        system_message: System prompt given to the agent.
        max_tokens, temperature, top_p, top_k, repeat_penalty: Sampling
            values copied onto the provider's default settings.

    Yields:
        The response text accumulated so far. If model loading or generation
        fails, an error message is yielded instead.
    """
    chat_template = MessagesFormatterType.GEMMA_2

    print("HISTORY SO FAR ", history)

    # Only the second interaction (exactly one prior exchange) may select a
    # document section. The keyword is matched as a substring of the
    # lower-cased message, and the first match in dict order wins.
    matched_path = None
    if len(history) == 1:
        lowered = message.lower()
        matched_path = next(
            (path for key, path in documents_paths.items() if key in lowered),
            None,
        )

    print("matched_path", matched_path)

    if matched_path:
        # Second interaction: index the chosen section, then answer the
        # user's original (first) question using that context.
        original_message = history[0][0]
        print("** matched path!!")
        query_engine = backend.create_index_for_query_engine(matched_path)
        message = backend.generate_prompt(query_engine, original_message)
        gr.Info("Relevant context indexed from docs...")
    elif len(history) > 1:
        print("Using context from storage db")
        query_engine = backend.load_index_for_query_engine()
        message = backend.generate_prompt(query_engine, message)
        gr.Info("Relevant context extracted from db...")

    # Load model only if it's not already loaded or if a new model is selected
    if backend.llm is None or backend.llm_model != model:
        try:
            backend.load_model(model)
        except Exception as e:
            # `respond` is a generator, so `return <value>` would only set
            # StopIteration.value and the UI would show nothing. Yield the
            # error so it reaches the chat window, then stop.
            yield f"Error loading model: {str(e)}"
            return

    provider = LlamaCppPythonProvider(backend.llm)

    agent = LlamaCppAgent(
        provider,
        system_prompt=system_message,
        predefined_messages_formatter_type=chat_template,
        debug_output=True
    )

    settings = provider.get_provider_default_settings()
    settings.temperature = temperature
    settings.top_k = top_k
    settings.top_p = top_p
    settings.max_tokens = max_tokens
    settings.repeat_penalty = repeat_penalty
    settings.stream = True

    messages = BasicChatHistory()

    # Replay prior (user, assistant) turns so the agent sees the conversation.
    for msn in history:
        user = {'role': Roles.user, 'content': msn[0]}
        assistant = {'role': Roles.assistant, 'content': msn[1]}
        messages.add_message(user)
        messages.add_message(assistant)

    try:
        stream = agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False
        )

        # Yield the cumulative text so the UI shows the reply as it grows.
        outputs = ""
        for output in stream:
            outputs += output
            yield outputs
    except Exception as e:
        yield f"Error during response generation: {str(e)}"


# NOTE(review): `retry_btn`/`undo_btn`/`clear_btn` on ChatInterface and
# `likeable` on Chatbot are Gradio 4.x keyword arguments; they appear to be
# removed in Gradio 5 — confirm the pinned gradio version before upgrading.
demo = gr.ChatInterface(
    fn=respond,
    css="""
    .gradio-container {
        background-color: #B9D9EB;
        color: #003366;
    }""",
    # Order must match the extra parameters of `respond` after (message, history).
    additional_inputs=[
        gr.Dropdown([
            'Meta-Llama-3.1-8B-Instruct-Q5_K_M.gguf',
            'Mistral-Nemo-Instruct-2407-Q5_K_M.gguf',
            'gemma-2-2b-it-Q6_K_L.gguf',
            'openchat-3.6-8b-20240522-Q6_K.gguf',
            'Llama-3-Groq-8B-Tool-Use-Q6_K.gguf',
            'MiniCPM-V-2_6-Q6_K.gguf',
            'llama-3.1-storm-8b-q5_k_m.gguf',
            'orca-2-7b-patent-instruct-llama-2-q5_k_m.gguf'
        ],
            value="gemma-2-2b-it-Q6_K_L.gguf",
            label="Model"
        ),
        gr.Textbox(value="""Solamente all'inizio, presentati come Odi, un assistente ricercatore italiano creato dagli Osservatori del Politecnico di Milano e specializzato nel fornire risposte precise e pertinenti solo ad argomenti di innovazione digitale. Solo nella tua prima risposta, chiedi all'utente di indicare a quale di queste tre sezioni degli Osservatori si riferisce la sua domanda: 'Blockchain', 'Payment' o 'Metaverse'.
Per le risposte successive, utilizza la cronologia della chat o il contesto fornito per aiutare l'utente a ottenere una risposta accurata. Non rispondere mai a domande che non sono pertinenti a questi argomenti.""", label="System message"),
        gr.Slider(minimum=1, maximum=4096, value=3048, step=1, label="Max tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=1.2, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
        ),
        gr.Slider(
            minimum=0,
            maximum=100,
            value=30,
            step=1,
            label="Top-k",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=2.0,
            value=1.1,
            step=0.1,
            label="Repetition penalty",
        ),
    ],
    retry_btn="Riprova",
    undo_btn="Annulla",
    clear_btn="Pulisci",
    submit_btn="Invia",
    title="Odi, l'assistente ricercatore degli Osservatori",
    chatbot=gr.Chatbot(
        scale=1,
        likeable=False,
        show_copy_button=True
    ),
    examples=[["Ciao, in cosa puoi aiutarmi?"], ["Quanto vale il mercato italiano?"], ["Per favore dammi informazioni sugli ambiti applicativi"], ["Svelami una buona ricetta milanese"]],
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()