Spaces:
Sleeping
Sleeping
import json | |
import os | |
import gradio as gr | |
from huggingface_hub import InferenceClient | |
from app_modules.utils import calc_bleu_rouge_scores, detect_repetitions | |
from dotenv import find_dotenv, load_dotenv | |
# Locate a ".env" file with python-dotenv and actually load it into the
# process environment, so the settings read below (HF_RP here,
# QUESTIONS_FILE_PATH further down) can be supplied from that file.
# Previously the path was found but never loaded, so .env was ignored.
# find_dotenv returns "" when no file is found; override=False keeps any
# variable that is already set in the real environment.
found_dotenv = find_dotenv(".env")
if found_dotenv:
    load_dotenv(found_dotenv, override=False)

# Initial value of the "Repetition Penalty" slider in the UI; HF_RP
# overrides the "1.2" default (a non-numeric value raises ValueError).
HF_RP = os.getenv("HF_RP", "1.2")
repetition_penalty = float(HF_RP)
print(f" repetition_penalty: {repetition_penalty}")
# Evaluation dataset. The path comes from QUESTIONS_FILE_PATH; an unset *or*
# empty value falls back to the bundled file (presumably MS MARCO-style data,
# judging by the default filename — confirm).
questions_file_path = (
    os.getenv("QUESTIONS_FILE_PATH") or "./data/datasets/ms_macro.json"
)
# Context manager closes the handle deterministically; JSON text is UTF-8
# (RFC 8259), so don't depend on the platform's default locale encoding.
with open(questions_file_path, encoding="utf-8") as questions_file:
    questions = json.load(questions_file)

# One-element lists of the stripped question text. chat() looks a message up
# in this list by exact match to decide whether to build a RAG prompt.
examples = [[question["question"].strip()] for question in questions]
print(f"Loaded {len(examples)} examples")

# Instruction prepended (with the record's context) to dataset questions.
qa_system_prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
# Hosted model used for every chat request.
# For more information on `huggingface_hub` Inference API support, please
# check the docs:
# https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
def chat(
    message,
    history: list[tuple[str, str]],
    system_message,
    temperature=0,
    repetition_penalty=1.1,
    do_sample=True,
    max_tokens=1024,
    top_p=0.95,
):
    """Stream a model reply to ``message``, then append quality metrics.

    Used as the ``fn`` of ``gr.ChatInterface``. Each ``yield`` is the full
    reply text so far. The final ``yield`` adds a "Repetition Metrics"
    section. For dataset questions it also adds BLEU/ROUGE-L
    "Performance Metrics" against the record's reference answers.

    If ``message`` exactly matches one of the loaded dataset questions, it is
    wrapped in a RAG prompt (``qa_system_prompt`` + the record's ``context``)
    before being sent to the model.

    Args:
        message: The user's latest message.
        history: Previous (user, assistant) turns supplied by Gradio. Not
            forwarded to the model; each turn is answered on its own.
        system_message: Content of the system message sent to the model.
        temperature: Sampling temperature passed to ``chat_completion``.
        repetition_penalty: Only printed; it is not passed to the model.
        do_sample: Bound to the "Sampling" checkbox but currently unused.
        max_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling threshold passed to ``chat_completion``.

    Yields:
        The accumulated reply text, growing with each streamed chunk.
    """
    print("repetition_penalty:", repetition_penalty)
    # NOTE(review): ``history`` and ``do_sample`` exist so the signature
    # matches the Gradio inputs, but neither reaches the model. A previously
    # built (and never sent) history message list was removed. If multi-turn
    # context is wanted, note that past assistant turns contain the appended
    # metrics text below.

    # Dataset questions are matched by exact text; -1 marks a free-form chat.
    try:
        index = examples.index([message])
    except ValueError:
        index = -1
    else:
        message = f"{qa_system_prompt}\n\n{questions[index]['context']}\n\nQuestion: {message}"
        print("RAG prompt:", message)

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": message},
    ]

    # chat_completion is used rather than client.text_generation: passing the
    # messages to text_generation returned "422 Client Error: Unprocessable
    # Entity" for this model on the api-inference endpoint, even though the
    # model's metadata lists pipeline_tag "text-generation".
    partial_text = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # A stream chunk may have no choices, and ``delta.content`` is
        # optional; skip such chunks instead of failing on ``str += None``.
        if not chunk.choices:
            continue
        new_text = chunk.choices[0].delta.content
        if not new_text:
            continue
        partial_text += new_text
        yield partial_text

    answer = partial_text
    whitespace_score, repetition_score, total_repetitions = detect_repetitions(answer)
    # Every item is written as "1."; presumably relying on Markdown list
    # auto-numbering when the chatbot renders the reply.
    partial_text += "\n\nRepetition Metrics:\n"
    partial_text += f"1. Whitespace Score: {whitespace_score:.3f}\n"
    partial_text += f"1. Repetition Score: {repetition_score:.3f}\n"
    partial_text += f"1. Total Repetitions: {total_repetitions:.3f}\n"

    if index >= 0:  # RAG: score the answer against the record's references
        record = questions[index]
        # Prefer the curated "wellFormedAnswers" when the record has them.
        key = "wellFormedAnswers" if "wellFormedAnswers" in record else "answers"
        scores = calc_bleu_rouge_scores([answer], [record[key]], debug=True)
        partial_text += "\n\n Performance Metrics:\n"
        partial_text += f'1. BLEU-1: {scores["bleu_scores"]["bleu"]:.3f}\n'
        partial_text += f'1. RougeL: {scores["rouge_scores"]["rougeL"]:.3f}\n'

    yield partial_text
# Gradio chat UI wired to chat(). The extra inputs are passed positionally
# after (message, history), so their order must follow chat()'s parameters:
# system_message, temperature, repetition_penalty, do_sample, max_tokens,
# top_p. The accordion is created before the inputs, as in a single call.
_parameters_accordion = gr.Accordion(
    label="⚙️ Parameters",
    open=False,
    render=False,
)
_extra_inputs = [
    gr.Textbox(label="System message", value="You are a friendly Chatbot."),
    gr.Slider(
        label="Temperature",
        minimum=0,
        maximum=1,
        step=0.1,
        value=0,
        render=False,
    ),
    gr.Slider(
        label="Repetition Penalty",
        minimum=1.0,
        maximum=1.5,
        step=0.1,
        value=repetition_penalty,  # initial value from HF_RP
        render=False,
    ),
    gr.Checkbox(label="Sampling", value=True),
    gr.Slider(
        label="Max new tokens",
        minimum=128,
        maximum=4096,
        step=1,
        value=512,
        render=False,
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        minimum=0.1,
        maximum=1.0,
        step=0.05,
        value=0.95,
    ),
]

demo = gr.ChatInterface(
    fn=chat,
    examples=examples,
    cache_examples=False,
    additional_inputs_accordion=_parameters_accordion,
    additional_inputs=_extra_inputs,
)
demo.launch()