Aabbhishekk's picture
Create app.py
2ed8e0d verified
raw
history blame
4.84 kB
import hashlib
from pathlib import Path
from tempfile import NamedTemporaryFile

import streamlit as st
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
from langchain_community.llms import HuggingFaceHub
from langchain.document_loaders import AssemblyAIAudioTranscriptLoader
from langchain.embeddings import HuggingFaceHubEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.prompts import PromptTemplate
# Load environment variables from a .env file into the process environment.
# NOTE(review): presumably this supplies the API credentials for the Hugging Face
# and AssemblyAI clients used below — confirm which keys are expected.
load_dotenv()
# Build the prompt used by the retrieval QA chain
def create_qa_prompt() -> PromptTemplate:
    """Return the prompt template for answering a question from context.

    The template has two placeholders, ``context`` and ``question``, and
    instructs the model to reply "I DON'T KNOW" when the answer is unclear.
    """
    prompt_text = """\n\nHuman: Use the following pieces of context to answer the question at the end. If the answer is not clear, say I DON'T KNOW
{context}
Question: {question}
\n\nAssistant:
Answer:"""
    placeholders = ["context", "question"]
    return PromptTemplate(input_variables=placeholders, template=prompt_text)
# Transcribe each audio source into a document
def create_docs(urls_list):
    """Transcribe every entry of ``urls_list`` and return one document per entry.

    Each entry is passed to ``AssemblyAIAudioTranscriptLoader`` as ``file_path``
    and only the first document the loader returns is kept. A progress line is
    written to the Streamlit page before each transcription starts.

    Args:
        urls_list: Iterable of audio locations accepted by the loader.

    Returns:
        A list with the first loaded document for each entry, in input order.
    """
    transcripts = []
    for source in urls_list:
        st.write(f'Transcribing {source}')
        loader = AssemblyAIAudioTranscriptLoader(file_path=source)
        loaded = loader.load()
        transcripts.append(loaded[0])
    return transcripts
# Function to create a Hugging Face embeddings model
def make_embedder():
    """Create the Hugging Face Hub embeddings client.

    Returns:
        A ``HuggingFaceHubEmbeddings`` instance configured with the
        ``sentence-transformers/all-mpnet-base-v2`` repository and the
        ``feature-extraction`` task.
    """
    # No local `device` / `normalize_embeddings` options are configured here:
    # the client only receives the repo id and the task. (Dicts holding such
    # options used to be built in this function but were never passed on.)
    return HuggingFaceHubEmbeddings(
        repo_id="sentence-transformers/all-mpnet-base-v2",
        task="feature-extraction",
    )
# Build the language model used for question answering
def make_qa_chain():
    """Return the Hugging Face Hub LLM used to answer questions.

    Despite its name, this function currently returns the bare
    ``HuggingFaceHub`` LLM (``HuggingFaceH4/zephyr-7b-beta``), not a chain;
    the caller wraps it in a QA chain itself.
    """
    generation_settings = {
        "max_new_tokens": 512,
        "top_k": 30,
        "temperature": 0.01,
        "repetition_penalty": 1.5,
    }
    return HuggingFaceHub(
        repo_id="HuggingFaceH4/zephyr-7b-beta",
        model_kwargs=generation_settings,
    )
    # Disabled retrieval variant (it would need a `db` vector store):
    # return RetrievalQA.from_chain_type(
    #     llm,
    #     retriever=db.as_retriever(search_type="mmr", search_kwargs={'fetch_k': 3}),
    #     return_source_documents=True,
    #     chain_type_kwargs={
    #         "prompt": create_qa_prompt(),
    #     }
    # )
# Streamlit UI
def _transcribe_upload(uploaded_file):
    """Write the uploaded audio to a temporary file and return its transcript documents."""
    # Keep the upload's own extension (.wav or .mp3) rather than always using
    # ".mp3"; fall back to ".mp3" only when the name has no extension.
    suffix = Path(uploaded_file.name).suffix or '.mp3'
    with NamedTemporaryFile(suffix=suffix) as temp:
        temp.write(uploaded_file.getvalue())
        # Flush so every byte is on disk before the file is read back by name.
        temp.flush()
        return create_docs([temp.name])


def main():
    """Render the two-pane page: audio upload on the left, question form on the right.

    The uploaded audio is transcribed once per distinct file content and the
    resulting documents are kept in ``st.session_state``. Submitting a
    non-empty question runs a "stuff" QA chain over those documents and
    prints the answer.
    """
    st.set_page_config(page_title="Audio Query Chatbot", page_icon=":microphone:", layout="wide")
    col1, col2 = st.columns([1, 2])
    docs = None

    # Left pane - Audio file upload
    with col1:
        st.header("Upload Audio File")
        uploaded_file = st.file_uploader("Choose a WAV or MP3 file", type=["wav", "mp3"], key="audio_uploader")
        if uploaded_file is not None:
            # Streamlit re-runs this script on every interaction, including each
            # form submission. Cache the transcript by content hash so the same
            # audio is not transcribed again for every question.
            digest = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
            if st.session_state.get("transcript_digest") != digest:
                st.session_state["transcript_docs"] = _transcribe_upload(uploaded_file)
                # Record the digest last so a failed transcription is retried.
                st.session_state["transcript_digest"] = digest
            docs = st.session_state["transcript_docs"]
            st.success('Audio file transcribed successfully!')
            # NOTE: vector-store retrieval (text splitter + FAISS + RetrievalQA)
            # is currently disabled; the transcript documents are passed to the
            # QA chain directly.

    # Right pane - Chatbot Interface
    with col2:
        st.header("Chatbot Interface")
        if uploaded_file is not None:
            with st.form(key="form"):
                user_input = st.text_input("Ask your question", key="user_input")
                st.markdown("<div><br></div>", unsafe_allow_html=True)  # Adds some space
                # Extra bottom margin for the form's input field.
                st.markdown(
                    "<style>\n#form input {margin-bottom: 15px;}\n</style>",
                    unsafe_allow_html=True,
                )
                submit = st.form_submit_button("Submit Question")

            # Display the result once the form is submitted
            if submit and not user_input.strip():
                # Do not spend an LLM call on a blank question.
                st.warning("Please enter a question before submitting.")
            elif submit:
                llm = make_qa_chain()
                chain = load_qa_chain(llm, chain_type="stuff")
                result = chain.run(question=user_input, input_documents=docs)
                st.success("Query Result:")
                st.write(f"User: {user_input}")
                st.write(f"Assistant: {result}")
# Build the Streamlit page only when this file is run as the main script.
if __name__ == "__main__":
    main()