NCTCMumbai's picture
Update backend/query_llm.py
a095bcc verified
raw
history blame
10.9 kB
import openai
import gradio as gr
from os import getenv
from typing import Any, Dict, Generator, List
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer
import google.generativeai as genai
import os
import PIL.Image
import gradio as gr
#from gradio_multimodalchatbot import MultimodalChatbot
from gradio.data_classes import FileData
# Tokenizer for the hosted model; its chat template is what format_prompt()
# applies for the "hf" backend.
# Alternative checkpoint kept for reference: "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
)

# Module-level sampling settings.
# Alternative values kept for reference:
#   temperature = 0.2, top_p = 0.6, repetition_penalty = 1.0
temperature = 0.5
top_p = 0.7
repetition_penalty = 1.2

# Credentials are read from the environment; each is None when the variable
# is not set.
GOOGLE_API_KEY = getenv('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)
OPENAI_KEY = getenv("OPENAI_API_KEY")
HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")

# Inference client bound to the same checkpoint as the tokenizer above.
# Alternative checkpoint kept for reference: "mistralai/Mistral-7B-Instruct-v0.1"
hf_client = InferenceClient(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    token=HF_TOKEN,
)
def format_prompt(message: str, api_kind: str):
    """
    Formats a single user message for the requested backend API.

    Args:
        message (str): The user message to be formatted.
        api_kind (str): Target backend, one of:
            "openai" - returns a one-element list of {'role', 'content'} dicts.
            "hf"     - returns the string produced by the module-level
                       tokenizer's chat template (tokenize=False).
            "gemini" - returns a one-element list of {'role', 'parts'} dicts.

    Returns:
        list | str: A message list for "openai" and "gemini", a string for "hf".

    Raises:
        ValueError: If api_kind is anything else, including an empty string
            or None.
    """
    if api_kind == "gemini":
        # Gemini content entries carry the text under 'parts', not 'content'.
        return [{'role': 'user', 'parts': [message]}]

    # Create a list of message dictionaries with role and content
    messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]
    if api_kind == "openai":
        return messages
    if api_kind == "hf":
        return tokenizer.apply_chat_template(messages, tokenize=False)

    # Reject every other value explicitly so a missing/empty kind cannot
    # fall through and hand the caller None as a "prompt".
    raise ValueError("API is not supported")
def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 4000,
                top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
    """
    Stream generated text for a prompt through the module-level Hugging Face
    inference client.

    Args:
        prompt (str): The user prompt; it is passed through the "hf" chat template.
        history (str): Accepted for signature parity with the other generators;
            this function does not read it.
        temperature (float, optional): Sampling temperature, floored at 0.01.
            Defaults to 0.9.
        max_new_tokens (int, optional): Maximum number of tokens to generate.
            Defaults to 4000.
        top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
        repetition_penalty (float, optional): Penalty for repeated tokens.
            Defaults to 1.0.

    Returns:
        Generator[str, None, str]: Yields the accumulated text after every
        streamed token. If the request fails, the generator ends with a short
        message as its return value (it is not yielded).
    """
    # Sampling is always enabled and uses a fixed seed.
    sampling_options = dict(
        temperature=max(float(temperature), 1e-2),  # Ensure temperature isn't too low
        max_new_tokens=max_new_tokens,
        top_p=float(top_p),
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    chat_text = format_prompt(prompt, "hf")

    try:
        token_stream = hf_client.text_generation(
            chat_text,
            **sampling_options,
            stream=True,
            details=True,
            return_full_text=False,
        )
        generated = ""
        for item in token_stream:
            generated += item.token.text
            yield generated
    except Exception as e:
        # Failures are classified by substring of the exception text.
        reason = str(e)
        if "Too Many Requests" in reason:
            print("ERROR: Too many requests on Mistral client")
            gr.Warning("Unfortunately Mistral is unable to process")
            return "Unfortunately, I am not able to process your request now."
        if "Authorization header is invalid" in reason:
            print("Authetification error:", reason)
            gr.Warning("Authentication error: HF token was either not provided or incorrect")
            return "Authentication error"
        print("Unhandled Exception:", reason)
        gr.Warning("Unfortunately Mistral is unable to process")
        return "I do not know what happened, but I couldn't understand you."
def generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 256,
                    top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
    """
    Stream a chat completion for a prompt from the OpenAI ChatCompletion API.

    Args:
        prompt (str): The user prompt; it is sent as a single user message.
        history (str): Accepted for signature parity with the other generators;
            this function does not read it.
        temperature (float, optional): Sampling temperature, floored at 0.01.
            Defaults to 0.9.
        max_new_tokens (int, optional): Sent as `max_tokens`. Defaults to 256.
        top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
        repetition_penalty (float, optional): Clamped to [-2.0, 2.0] and sent
            as `frequency_penalty`. Defaults to 1.0.

    Returns:
        Generator[str, None, str]: Yields the accumulated text after every
        streamed chunk. If the request fails, the generator ends with a short
        message as its return value (it is not yielded).
    """
    temperature = max(float(temperature), 1e-2)  # Ensure temperature isn't too low
    top_p = float(top_p)

    generate_kwargs = {
        'temperature': temperature,
        'max_tokens': max_new_tokens,
        'top_p': top_p,
        'frequency_penalty': max(-2., min(repetition_penalty, 2.)),
    }

    # `messages=` needs a list of role/content dicts, which is what the
    # "openai" kind returns; the "hf" kind would produce a templated string.
    formatted_prompt = format_prompt(prompt, "openai")

    try:
        stream = openai.ChatCompletion.create(model="gpt-3.5-turbo-0301",
                                              messages=formatted_prompt,
                                              **generate_kwargs,
                                              stream=True)
        output = ""
        for chunk in stream:
            # Some chunks carry no content; treat a missing or None value as "".
            output += chunk.choices[0].delta.get("content", "") or ""
            yield output

    except Exception as e:
        # Failures are classified by substring of the exception text.
        error_text = str(e)
        if "Too Many Requests" in error_text:
            print("ERROR: Too many requests on OpenAI client")
            gr.Warning("Unfortunately OpenAI is unable to process")
            return "Unfortunately, I am not able to process your request now."
        elif "You didn't provide an API key" in error_text:
            print("Authetification error:", error_text)
            gr.Warning("Authentication error: OpenAI key was either not provided or incorrect")
            return "Authentication error"
        else:
            print("Unhandled Exception:", error_text)
            gr.Warning("Unfortunately OpenAI is unable to process")
            return "I do not know what happened, but I couldn't understand you."
def generate_gemini(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 4000,
                    top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
    """
    Stream generated text for a prompt from the Gemini ('gemini-pro') model.

    Args:
        prompt (str): The user prompt; it is sent as a single user turn.
        history (str): Accepted for signature parity with the other generators;
            this function does not read it.
        temperature (float, optional): Sampling temperature. Defaults to 0.9.
        max_new_tokens (int, optional): Sent as `max_output_tokens`.
            Defaults to 4000.
        top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
        repetition_penalty (float, optional): Accepted for signature parity;
            it is not forwarded to the Gemini request. Defaults to 1.0.

    Returns:
        Generator[str, None, str]: Yields the accumulated text after every
        streamed chunk. If the request fails, the generator ends with a short
        message as its return value (it is not yielded).
    """
    # For better security practices, the API key comes from the environment.
    api_key = os.environ.get("GOOGLE_API_KEY")
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel('gemini-pro')

    # Built inline so this function does not depend on format_prompt()
    # recognising a "gemini" kind.
    contents = [{'role': 'user', 'parts': [prompt]}]

    try:
        generation_config = genai.GenerationConfig(
            temperature=temperature,
            candidate_count=1,
            max_output_tokens=max_new_tokens,
            top_p=top_p,
        )
        stream = model.generate_content(contents,
                                        generation_config=generation_config,
                                        stream=True)
        output = ""
        for response in stream:
            output += response.text
            yield output

    except Exception as e:
        # Failures are classified by substring of the exception text.
        error_text = str(e)
        # NOTE(review): "Resource has been exhausted" and "API key not valid"
        # are presumably the Gemini-side wordings for rate limiting and a bad
        # key; confirm against live error text. The older substrings are kept.
        rate_limited = ("Too Many Requests" in error_text
                        or "Resource has been exhausted" in error_text)
        bad_credentials = ("API key not valid" in error_text
                           or "Authorization header is invalid" in error_text)
        if rate_limited:
            print("ERROR: Too many requests on Gemini client")
            gr.Warning("Unfortunately Gemini is unable to process")
            return "Unfortunately, I am not able to process your request now."
        elif bad_credentials:
            print("Authentication error:", error_text)
            gr.Warning("Authentication error: Google API key was either not provided or incorrect")
            return "Authentication error"
        else:
            print("Unhandled Exception:", error_text)
            gr.Warning("Unfortunately Gemini is unable to process")
            return "I do not know what happened, but I couldn't understand you."
# def gemini(input, file, chatbot=[]):
# """
# Function to handle gemini model and gemini vision model interactions.
# Parameters:
# input (str): The input text.
# file (File): An optional file object for image processing.
# chatbot (list): A list to keep track of chatbot interactions.
# Returns:
# tuple: Updated chatbot interaction list, an empty string, and None.
# """
# messages = []
# print(chatbot)
# # Process previous chatbot messages if present
# if len(chatbot) != 0:
# for messages_dict in chatbot:
# user_text = messages_dict[0]['text']
# bot_text = messages_dict[1]['text']
# messages.extend([
# {'role': 'user', 'parts': [user_text]},
# {'role': 'model', 'parts': [bot_text]}
# ])
# messages.append({'role': 'user', 'parts': [input]})
# else:
# messages.append({'role': 'user', 'parts': [input]})
# try:
# response = model.generate_content(messages)
# gemini_resp = response.text
# # Construct list of messages in the required format
# user_msg = {"text": input, "files": []}
# bot_msg = {"text": gemini_resp, "files": []}
# chatbot.append([user_msg, bot_msg])
# except Exception as e:
# # Handling exceptions and raising error to the modal
# print(f"An error occurred: {e}")
# raise gr.Error(e)
# return chatbot, "", None
# # Define the Gradio Blocks interface
# with gr.Blocks() as demo:
# # Add a centered header using HTML
# gr.HTML("<center><h1>Gemini Chat PRO API</h1></center>")
# # Initialize the MultimodalChatbot component
# multi = MultimodalChatbot(value=[], height=800)
# with gr.Row():
# # Textbox for user input with increased scale for better visibility
# tb = gr.Textbox(scale=4, placeholder='Input text and press Enter')
# # Define the behavior on text submission
# tb.submit(gemini, [tb, multi], [multi, tb])
# # Launch the demo with a queue to handle multiple users
# demo.queue().launch()