# Audio Query Chatbot — Streamlit app (HuggingFace Spaces)
# Transcribes an uploaded audio file with AssemblyAI, then answers questions
# about the transcript using a HuggingFace-hosted LLM via LangChain.
# Standard library
from tempfile import NamedTemporaryFile

# Third-party
import streamlit as st
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
from langchain.document_loaders import AssemblyAIAudioTranscriptLoader
from langchain.embeddings import HuggingFaceHubEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_community.llms import HuggingFaceHub

# Load environment variables from .env — presumably the HuggingFace Hub and
# AssemblyAI API keys; confirm against deployment config.
load_dotenv()
def create_qa_prompt() -> PromptTemplate:
    """Build the prompt for the retrieval-QA chain.

    The template wraps the retrieved ``context`` and the user ``question``
    in a Human/Assistant exchange and instructs the model to reply
    "I DON'T KNOW" when the answer is not clear from the context.

    Returns:
        PromptTemplate with input variables ``context`` and ``question``.
    """
    template = """\n\nHuman: Use the following pieces of context to answer the question at the end. If the answer is not clear, say I DON'T KNOW
{context}
Question: {question}
\n\nAssistant:
Answer:"""
    return PromptTemplate(template=template, input_variables=["context", "question"])
def create_docs(urls_list):
    """Transcribe each audio file in *urls_list* into a LangChain document.

    Progress is reported in the Streamlit UI as each item is transcribed;
    the first document produced by each transcript loader is collected.
    """
    transcribed = []
    for url in urls_list:
        st.write(f'Transcribing {url}')
        loader = AssemblyAIAudioTranscriptLoader(file_path=url)
        transcribed.append(loader.load()[0])
    return transcribed
def make_embedder():
    """Create the Hugging Face Hub embeddings client.

    Uses the hosted ``sentence-transformers/all-mpnet-base-v2`` model with
    the feature-extraction task.

    NOTE(review): the original defined ``model_kwargs={'device': 'cpu'}``
    and ``encode_kwargs={'normalize_embeddings': False}`` but never passed
    them to HuggingFaceHubEmbeddings (those are parameters of the local
    HuggingFaceEmbeddings class) — removed as dead code; behavior unchanged.
    """
    model_name = "sentence-transformers/all-mpnet-base-v2"
    return HuggingFaceHubEmbeddings(
        repo_id=model_name,
        task="feature-extraction",
    )
def make_qa_chain():
    """Create the LLM used for question answering.

    NOTE(review): despite the name, this returns a bare HuggingFaceHub LLM,
    not a RetrievalQA chain — ``main()`` wires it into ``load_qa_chain``
    itself. A commented-out RetrievalQA construction (MMR retriever over a
    FAISS store with the custom prompt) was removed as dead code.

    Returns:
        A HuggingFaceHub LLM (zephyr-7b-beta) configured for low-temperature,
        repetition-penalized generation.
    """
    llm = HuggingFaceHub(
        repo_id="HuggingFaceH4/zephyr-7b-beta",
        model_kwargs={
            "max_new_tokens": 512,
            "top_k": 30,
            "temperature": 0.01,
            "repetition_penalty": 1.5,
        },
    )
    return llm
# Streamlit UI
def main():
    """Run the app: upload an audio file, transcribe it, answer questions.

    Left column handles upload + transcription; right column hosts the
    question form. The transcript documents are fed directly into a
    "stuff"-type QA chain (no vector store — commented-out FAISS/splitter
    code was removed as dead).
    """
    st.set_page_config(page_title="Audio Query Chatbot", page_icon=":microphone:", layout="wide")

    docs = None  # transcript documents; set once a file has been uploaded

    # Left pane - audio file upload and transcription
    col1, col2 = st.columns([1, 2])
    with col1:
        st.header("Upload Audio File")
        uploaded_file = st.file_uploader("Choose a WAV or MP3 file", type=["wav", "mp3"], key="audio_uploader")
        if uploaded_file is not None:
            # Persist the upload to a temp file so the transcript loader
            # can read it by path; the file is deleted on context exit.
            with NamedTemporaryFile(suffix='.mp3') as temp:
                temp.write(uploaded_file.getvalue())
                temp.seek(0)
                docs = create_docs([temp.name])
            st.success('Audio file transcribed successfully!')

    # Right pane - Chatbot Interface
    with col2:
        st.header("Chatbot Interface")
        if uploaded_file is not None:
            with st.form(key="form"):
                user_input = st.text_input("Ask your question", key="user_input")
                # Automatically submit the form on Enter key press
                st.markdown("<div><br></div>", unsafe_allow_html=True)  # Adds some space
                st.markdown(
                    """<style>
#form input {margin-bottom: 15px;}
</style>""", unsafe_allow_html=True
                )
                submit = st.form_submit_button("Submit Question")

                # Display the result once the form is submitted
                if submit:
                    llm = make_qa_chain()
                    chain = load_qa_chain(llm, chain_type="stuff")
                    result = chain.run(question=user_input, input_documents=docs)
                    st.success("Query Result:")
                    st.write(f"User: {user_input}")
                    st.write(f"Assistant: {result}")


if __name__ == "__main__":
    main()