Aabbhishekk committed on
Commit 2ed8e0d
1 Parent(s): b01d6c1

Create app.py

Files changed (1)
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
+ import streamlit as st
+ from dotenv import load_dotenv
+ from langchain.chains import RetrievalQA
+ from langchain.chains.question_answering import load_qa_chain
+ from langchain_community.llms import HuggingFaceHub
+ from langchain.document_loaders import AssemblyAIAudioTranscriptLoader
+ from langchain.embeddings import HuggingFaceHubEmbeddings
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import FAISS
+ from langchain.prompts import PromptTemplate
+ from tempfile import NamedTemporaryFile
+
+ # RetrievalQA, RecursiveCharacterTextSplitter, and FAISS are only used by the
+ # commented-out retrieval path further down.
+
+ # Load environment variables
+ load_dotenv()
+
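+ # Note: load_dotenv() is expected to supply ASSEMBLYAI_API_KEY (read by the
+ # AssemblyAI transcript loader) and HUGGINGFACEHUB_API_TOKEN (read by the
+ # Hugging Face Hub LLM and embeddings); these key names are the libraries'
+ # defaults and are not set anywhere in this file.
+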
+ # Function to create a prompt for the retrieval QA chain
+ def create_qa_prompt() -> PromptTemplate:
+     template = """\n\nHuman: Use the following pieces of context to answer the question at the end. If the answer is not clear, say I DON'T KNOW
+     {context}
+     Question: {question}
+     \n\nAssistant:
+     Answer:"""
+
+     return PromptTemplate(template=template, input_variables=["context", "question"])
+
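+ # At run time, PromptTemplate fills {context} with the retrieved transcript
+ # chunks and {question} with the user's query before the text reaches the LLM.
+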
+ # Function to transcribe a list of audio paths/URLs into LangChain documents
+ def create_docs(urls_list):
+     documents = []
+     for url in urls_list:
+         st.write(f'Transcribing {url}')
+         documents.append(AssemblyAIAudioTranscriptLoader(file_path=url).load()[0])
+     return documents
+
+ # Function to create a Hugging Face embeddings client (for the FAISS path below)
+ def make_embedder():
+     model_name = "sentence-transformers/all-mpnet-base-v2"
+     # HuggingFaceHubEmbeddings runs remotely, so the local-only model_kwargs/
+     # encode_kwargs options of HuggingFaceEmbeddings do not apply here.
+     return HuggingFaceHubEmbeddings(
+         repo_id=model_name,
+         task="feature-extraction"
+     )
+
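+ # Usage sketch (illustrative): the embedder turns text chunks into vectors, e.g.
+ #   hf = make_embedder()
+ #   vectors = hf.embed_documents(["some transcript chunk"])  # one vector per text
+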
+ # Function to create the LLM; the full RetrievalQA chain is kept below, commented out
+ def make_qa_chain():
+     llm = HuggingFaceHub(
+         repo_id="HuggingFaceH4/zephyr-7b-beta",
+         model_kwargs={
+             "max_new_tokens": 512,
+             "top_k": 30,
+             "temperature": 0.01,
+             "repetition_penalty": 1.5,
+         },
+     )
+     return llm
+     # return RetrievalQA.from_chain_type(
+     #     llm,
+     #     retriever=db.as_retriever(search_type="mmr", search_kwargs={'fetch_k': 3}),
+     #     return_source_documents=True,
+     #     chain_type_kwargs={
+     #         "prompt": create_qa_prompt(),
+     #     }
+     # )
+
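+ # If re-enabled (with the FAISS index passed in as `db`), RetrievalQA with
+ # return_source_documents=True returns a dict holding "result" and
+ # "source_documents", which the commented display code at the bottom expects.
+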
+ # Streamlit UI
+ def main():
+     st.set_page_config(page_title="Audio Query Chatbot", page_icon=":microphone:", layout="wide")
+
+     # Left pane - Audio file upload
+     col1, col2 = st.columns([1, 2])
+
+     with col1:
+         st.header("Upload Audio File")
+         uploaded_file = st.file_uploader("Choose a WAV or MP3 file", type=["wav", "mp3"], key="audio_uploader")
+
+         if uploaded_file is not None:
+             with NamedTemporaryFile(suffix='.mp3') as temp:
+                 temp.write(uploaded_file.getvalue())
+                 temp.seek(0)
+                 docs = create_docs([temp.name])
+
+             # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+             # texts = text_splitter.split_documents(docs)
+
+             # for text in texts:
+             #     text.metadata = {"audio_url": text.metadata["audio_url"]}
+
+             st.success('Audio file transcribed successfully!')
+
+             # hf = make_embedder()
+             # db = FAISS.from_documents(texts, hf)
+
+             # qa_chain = make_qa_chain(db)
+
+     # Right pane - Chatbot Interface
+     with col2:
+         st.header("Chatbot Interface")
+
+         if uploaded_file is not None:
+             with st.form(key="form"):
+                 user_input = st.text_input("Ask your question", key="user_input")
+
+                 # Styling only; a form already submits on Enter in a text input
+                 st.markdown("<div><br></div>", unsafe_allow_html=True)  # adds some space
+                 st.markdown(
+                     """<style>
+                     #form input {margin-bottom: 15px;}
+                     </style>""", unsafe_allow_html=True
+                 )
+
+                 submit = st.form_submit_button("Submit Question")
+
+             # Display the result once the form is submitted
+             if submit:
+                 llm = make_qa_chain()
+                 chain = load_qa_chain(llm, chain_type="stuff")
+                 # docs = db.similarity_search(user_input)
+                 result = chain.run(question=user_input, input_documents=docs)
+                 # result = qa_chain.invoke(user_input)
+                 # result = qa_chain({"query": user_input})
+                 st.success("Query Result:")
+                 st.write(f"User: {user_input}")
+                 st.write(f"Assistant: {result}")
+
+                 # st.subheader("Source Documents:")
+                 # for idx, elt in enumerate(result['source_documents']):
+                 #     st.write(f"Source {idx + 1}:")
+                 #     st.write(f"Filepath: {elt.metadata['audio_url']}")
+                 #     st.write(f"Contents: {elt.page_content}")
+
+ if __name__ == "__main__":
+     main()
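+
+ # To run locally (assumed setup): install streamlit, python-dotenv, langchain,
+ # langchain-community, and assemblyai (plus faiss-cpu if the retrieval path is
+ # re-enabled), then launch with `streamlit run app.py`.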