import colorsys
import json
import logging
import re
import time
from typing import Dict, List, Optional

import jwt
import nltk
import numpy as np
import requests
import streamlit as st
from google.cloud import texttospeech
from google.oauth2 import service_account
from nltk import tokenize
from sentence_transformers import CrossEncoder, SentenceTransformer, util

nltk.download("punkt")

JWT_SECRET = st.secrets["api_secret"]
JWT_ALGORITHM = st.secrets["api_algorithm"]
INFERENCE_TOKEN = st.secrets["api_inference"]
CONTEXT_API_URL = st.secrets["api_context"]
LFQA_API_URL = st.secrets["api_lfqa"]

headers = {"Authorization": f"Bearer {INFERENCE_TOKEN}"}
API_URL = "https://api-inference.huggingface.co/models/askainet/bart_lfqa"
API_URL_TTS = "https://api-inference.huggingface.co/models/espnet/kan-bayashi_ljspeech_joint_finetune_conformer_fastspeech2_hifigan"

logger = logging.getLogger(__name__)


def api_inference_lfqa(model_input: str):
    """Generate an answer with the hosted Hugging Face Inference API."""
    payload = {
        "inputs": model_input,
        "parameters": {
            "truncation": "longest_first",
            "min_length": st.session_state["min_length"],
            "max_length": st.session_state["max_length"],
            "do_sample": st.session_state["do_sample"],
            "early_stopping": st.session_state["early_stopping"],
            "num_beams": st.session_state["num_beams"],
            "temperature": st.session_state["temperature"],
            "top_k": None,
            "top_p": None,
            "no_repeat_ngram_size": 3,
            "num_return_sequences": 1,
        },
        "options": {"wait_for_model": True},
    }
    data = json.dumps(payload)
    logger.debug(data)
    response = requests.request("POST", API_URL, headers=headers, data=data)
    return json.loads(response.content.decode("utf-8"))


def inference_lfqa(model_input: str, header: dict):
    """Generate an answer with the self-hosted LFQA service."""
    payload = {
        "model_input": model_input,
        "parameters": {
            "min_length": st.session_state["min_length"],
            "max_length": st.session_state["max_length"],
            "do_sample": st.session_state["do_sample"],
            "early_stopping": st.session_state["early_stopping"],
            "num_beams": st.session_state["num_beams"],
            "temperature": st.session_state["temperature"],
            "top_k": None,
            "top_p": None,
            "no_repeat_ngram_size": 3,
            "num_return_sequences": 1,
        },
    }
    data = json.dumps(payload)
    try:
        response = requests.request("POST", LFQA_API_URL, headers=header, data=data)
        if response.status_code == 200:
            json_response = response.content.decode("utf-8")
            result = json.loads(json_response)
        else:
            result = {"error": f"LFQA service unavailable, status code={response.status_code}"}
    except requests.exceptions.RequestException as e:
        result = {"error": str(e)}
    return result


def invoke_lfqa(service_backend: str, model_input: str, header: Optional[dict]):
    """Dispatch the request to the LFQA backend selected in the UI."""
    if service_backend == "HuggingFace":
        inference_response = api_inference_lfqa(model_input)
    else:
        inference_response = inference_lfqa(model_input, header)
    return inference_response
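# Two text-to-speech backends follow: the Hugging Face Inference API (an
# ESPnet FastSpeech2/HiFi-GAN model) and Google Cloud Text-to-Speech,
# authenticated with a service account assembled from individual secrets.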
@st.cache(allow_output_mutation=True, show_spinner=False)
def hf_tts(text: str):
    payload = {
        "inputs": text,
        "parameters": {
            "vocoder_tag": "str_or_none(none)",
            "threshold": 0.5,
            "minlenratio": 0.0,
            "maxlenratio": 10.0,
            "use_att_constraint": False,
            "backward_window": 1,
            "forward_window": 3,
            "speed_control_alpha": 1.0,
            "noise_scale": 0.333,
            "noise_scale_dur": 0.333,
        },
        "options": {"wait_for_model": True},
    }
    data = json.dumps(payload)
    response = requests.request("POST", API_URL_TTS, headers=headers, data=data)
    return response.content


@st.cache(allow_output_mutation=True, show_spinner=False)
def google_tts(text: str, private_key_id: str, private_key: str, client_email: str):
    config = {
        "private_key_id": private_key_id,
        "private_key": f"-----BEGIN PRIVATE KEY-----\n{private_key}\n-----END PRIVATE KEY-----\n",
        "client_email": client_email,
        "token_uri": "https://oauth2.googleapis.com/token",
    }
    credentials = service_account.Credentials.from_service_account_info(config)
    client = texttospeech.TextToSpeechClient(credentials=credentials)
    synthesis_input = texttospeech.SynthesisInput(text=text)

    # Build the voice request, select the language code ("en-US") and the
    # ssml voice gender ("neutral")
    voice = texttospeech.VoiceSelectionParams(
        language_code="en-US", ssml_gender=texttospeech.SsmlVoiceGender.NEUTRAL
    )

    # Select the type of audio file you want returned
    audio_config = texttospeech.AudioConfig(audio_encoding=texttospeech.AudioEncoding.MP3)

    # Perform the text-to-speech request on the text input with the selected
    # voice parameters and audio file type
    response = client.synthesize_speech(
        input=synthesis_input, voice=voice, audio_config=audio_config
    )
    return response


def request_context_passages(question, header):
    """Fetch Wikipedia context passages for the question from the context service."""
    try:
        response = requests.request("GET", CONTEXT_API_URL + question, headers=header)
        if response.status_code == 200:
            json_response = response.content.decode("utf-8")
            result = json.loads(json_response)
        else:
            result = {"error": f"Context passage service unavailable, status code={response.status_code}"}
    except requests.exceptions.RequestException as e:
        result = {"error": str(e)}
    return result


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_sentence_transformer():
    return SentenceTransformer("all-MiniLM-L6-v2")


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_sentence_transformer_encoding(sentences):
    model = get_sentence_transformer()
    return model.encode(sentences, convert_to_tensor=True)


def sign_jwt() -> str:
    """Create a short-lived token for the protected backend services."""
    payload = {"expires": time.time() + 6000}
    token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
    return token


def extract_sentences_from_passages(passages):
    sentences = []
    for node in passages:
        sentences.extend(tokenize.sent_tokenize(node["text"]))
    return sentences


def similarity_color_picker(similarity: float):
    # Map similarity in [0, 1] to a hue from red (low) toward green (high).
    value = int(similarity * 75)
    rgb = colorsys.hsv_to_rgb(value / 300.0, 1.0, 1.0)
    return [round(255 * x) for x in rgb]


def rgb_to_hex(rgb):
    return "%02x%02x%02x" % tuple(rgb)


def similiarity_to_hex(similarity: float):
    return rgb_to_hex(similarity_color_picker(similarity))


def rerank(question: str, passages: List[dict], include_rank: int = 4) -> List[dict]:
    """Rerank retrieved passages with a cross-encoder and keep the top few."""
    ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    question_passage_combinations = [[question, p["text"]] for p in passages]

    # Compute the similarity scores for these combinations
    similarity_scores = ce.predict(question_passage_combinations)

    # Sort the scores in decreasing order
    sim_ranking_idx = np.flip(np.argsort(similarity_scores))
    return [passages[rank_idx] for rank_idx in sim_ranking_idx[:include_rank]]


def answer_to_context_similarity(generated_answer, context_passages, topk=3):
    """For each answer sentence, find the topk most similar context sentences."""
    context_sentences = extract_sentences_from_passages(context_passages)
    context_sentences_e = get_sentence_transformer_encoding(context_sentences)
    answer_sentences = tokenize.sent_tokenize(generated_answer)
    answer_sentences_e = get_sentence_transformer_encoding(answer_sentences)
    search_result = util.semantic_search(answer_sentences_e, context_sentences_e, top_k=topk)
    result = []
    for idx, r in enumerate(search_result):
        context = []
        for idx_c in range(topk):
            context.append(
                {"source": context_sentences[r[idx_c]["corpus_id"]], "score": r[idx_c]["score"]}
            )
        result.append({"answer": answer_sentences[idx], "context": context})
    return result
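# The generator may stop mid-sentence when it reaches max_length. The regex
# below matches only sentences that end in proper punctuation, so for an
# answer like "Contrails are ice clouds. They form when" NLTK finds two
# sentences but the regex finds one, and the dangling fragment is dropped.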
def post_process_answer(generated_answer):
    result = generated_answer
    # regex pattern that detects complete sentence boundaries
    regex = r"([A-Z][a-z].*?[.:!?](?=$| [A-Z]))"
    answer_sentences = tokenize.sent_tokenize(generated_answer)
    # do we have a truncated last sentence?
    if len(answer_sentences) > len(re.findall(regex, generated_answer)):
        drop_last_sentence = " ".join(s for s in answer_sentences[:-1])
        result = drop_last_sentence
    return result.strip()


def format_score(value: float, precision=2):
    return f"{value:.{precision}f}"


# End-to-end answer pipeline: sign a short-lived JWT, fetch context passages,
# rerank them with the cross-encoder, condition the generator on the top
# passages, and post-process the generated answer.
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_answer(question: str):
    if not question:
        return {}

    resp: Dict[str, str] = {}
    if question and len(question.split()) > 3:
        header = {"Authorization": f"Bearer {sign_jwt()}"}
        context_passages = request_context_passages(question, header)
        if "error" in context_passages:
            resp = context_passages
        else:
            context_passages = rerank(question, context_passages)
            # bart_lfqa expects the passages joined with <P> separator tokens,
            # prefixed by the question
            conditioned_context = "<P> " + " <P> ".join([d["text"] for d in context_passages])
            model_input = f"question: {question} context: {conditioned_context}"
            inference_response = invoke_lfqa(
                st.session_state["api_lfqa_selector"], model_input, header
            )
            if "error" in inference_response:
                resp = inference_response
            else:
                resp["context_passages"] = context_passages
                resp["answer"] = post_process_answer(inference_response[0]["generated_text"])
    else:
        resp = {
            "error": f"A longer, more descriptive question will receive a better answer. '{question}' is too short."
        }
    return resp


def app():
    with open("style.css") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    footer = """
    """
    st.markdown(footer, unsafe_allow_html=True)

    st.title("Wikipedia Assistant")
    question = st.text_input(
        label='Ask Wikipedia an open-ended question below; for example, "Why do airplanes leave contrails in the sky?"'
    )
    spinner = st.empty()
    if question != "":
        spinner.markdown(
            f"""Generating answer for: {question}"""
        )