Samarth991 committed on
Commit
2faf743
1 Parent(s): 19b1878

application to run llama-7b on Audio files

Files changed (4)
  1. app.py +177 -0
  2. llm_ops.py +21 -0
  3. requirements.txt +12 -0
  4. whisper_app.py +69 -0
app.py ADDED
@@ -0,0 +1,177 @@
+ import time
+ import gradio as gr
+ import logging
+ from langchain.text_splitter import CharacterTextSplitter
+ from langchain.embeddings import SentenceTransformerEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.chains import RetrievalQA
+ from langchain.prompts import PromptTemplate
+ from langchain.docstore.document import Document
+ from whisper_app import WHISPERModel
+ import llm_ops
+
+ FILE_EXT = ['wav','mp3']
+ MAX_NEW_TOKENS = 4096
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ DEFAULT_TEMPERATURE = 0.1
+
+ def create_logger():
+     formatter = logging.Formatter('%(asctime)s:%(levelname)s:- %(message)s')
+     console_handler = logging.StreamHandler()
+     console_handler.setLevel(logging.INFO)
+     console_handler.setFormatter(formatter)
+
+     logger = logging.getLogger("APT_Realignment")
+     logger.setLevel(logging.INFO)
+
+     if not logger.hasHandlers():
+         logger.addHandler(console_handler)
+     logger.propagate = False
+     return logger
+
+
+ def create_prompt():
+     prompt_template = """Answer the questions regarding the content in the audio.
+     Use the following context to answer.
+     If you don't know the answer, just say I don't know.
+
+     {context}
+
+     Question: {question}
+     Answer :"""
+     prompt = PromptTemplate(
+         template=prompt_template, input_variables=["context", "question"]
+     )
+     return prompt
+
+
+ logger = create_logger()
+
+ def process_documents(documents, data_chunk=1500, chunk_overlap=100):
+     text_splitter = CharacterTextSplitter(chunk_size=data_chunk, chunk_overlap=chunk_overlap, separator='\n')
+     texts = text_splitter.split_documents(documents)
+     return texts
+
+ def audio_processor(wav_file, API_key, wav_model='small', llm='HuggingFace', temperature=0.1, max_tokens=4096):
+     device = 'cpu'
+     logger.info("Loading Whisper model || Model size: {}".format(wav_model))
+     whisper = WHISPERModel(model_name=wav_model, device=device)
+     text_info = whisper.speech_to_text(audio_path=wav_file)
+
+     metadata = {"source": f"{wav_file}", "duration": text_info['duration'], "language": text_info['language']}
+     document = [Document(page_content=text_info['text'], metadata=metadata)]
+     logger.info("Document: {}".format(document))
+     logger.info("Loading General Text Embeddings (GTE) model {}".format('thenlper/gte-large'))
+     embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large', model_kwargs={"device": device})
+     texts = process_documents(documents=document)
+     global vector_db
+     vector_db = FAISS.from_documents(documents=texts, embedding=embedding_model)
+     global qa
+     if llm == 'HuggingFace':
+         chat = llm_ops.get_hugging_face_model(
+             model_id="meta-llama/Llama-2-7b",
+             API_key=API_key,
+             temperature=temperature,
+             max_tokens=max_tokens
+         )
+     else:
+         chat = llm_ops.get_openai_chat_model(API_key=API_key)
+
+     chain_type_kwargs = {"prompt": create_prompt()}
+
+     qa = RetrievalQA.from_chain_type(llm=chat,
+                                      chain_type='stuff',
+                                      retriever=vector_db.as_retriever(),
+                                      chain_type_kwargs=chain_type_kwargs,
+                                      return_source_documents=True
+                                      )
+     return "Audio Processing completed ..."
+
+ def infer(question, history):
+     result = qa({"query": question})
+     matching_docs_score = vector_db.similarity_search_with_score(question)
+     logger.info("Matching Score: {}".format(matching_docs_score))
+     return result["result"]
+
+ def bot(history):
+     response = infer(history[-1][0], history)
+     history[-1][1] = ""
+
+     for character in response:
+         history[-1][1] += character
+         time.sleep(0.05)
+         yield history
+
+ def add_text(history, text):
+     history = history + [(text, None)]
+     return history, ""
+
+
+ def loading_file():
+     return "Loading..."
+
+
+ css = """
+ #col-container {max-width: 2048px; margin-left: auto; margin-right: auto;}
+ """
+
+ title = """
+ <div style="text-align: center;max-width: 2048px;">
+     <h1>Chat with Audio Files</h1>
+     <p style="text-align: center;">Upload an audio file of any lecture/song/research talk/conference & ask the chatbot questions about it.
+     <i>The tool uses state-of-the-art models from HuggingFace/OpenAI, so make sure to add your key.</i>
+     </p>
+ </div>
+ """
+ with gr.Blocks(css=css) as demo:
+     with gr.Row():
+         with gr.Column(elem_id="col-container"):
+             gr.HTML(title)
+
+     with gr.Column():
+         with gr.Row():
+             LLM_option = gr.Dropdown(['HuggingFace','OpenAI'], label='Select HuggingFace/OpenAI')
+             API_key = gr.Textbox(label="Add API key", type="password", autofocus=True)
+             wav_model = gr.Dropdown(['small','medium','large'], label='Select Whisper model')
+
+         with gr.Group():
+             chatbot = gr.Chatbot(height=270)
+
+             with gr.Row():
+                 question = gr.Textbox(label="Type your question !", lines=1).style(full_width=True)
+
+             with gr.Row():
+                 submit_btn = gr.Button(value="Send message", variant="primary", scale=1)
+                 clean_chat_btn = gr.Button("Delete Chat")
+
+     with gr.Column():
+         with gr.Box():
+             audio_file = gr.File(label="Upload Audio File ", file_types=FILE_EXT, type="file")
+         with gr.Accordion(label='Advanced options', open=False):
+             max_new_tokens = gr.Slider(
+                 label='Max new tokens',
+                 minimum=1024,
+                 maximum=MAX_NEW_TOKENS,
+                 step=1,
+                 value=DEFAULT_MAX_NEW_TOKENS,
+             )
+             temperature = gr.Slider(
+                 label='Temperature',
+                 minimum=0.1,
+                 maximum=4.0,
+                 step=0.1,
+                 value=DEFAULT_TEMPERATURE,
+             )
+         with gr.Row():
+             langchain_status = gr.Textbox(label="Status", placeholder="", interactive=False)
+             load_audio = gr.Button("Upload Audio File").style(full_width=False)
+
+     if audio_file:
+         load_audio.click(loading_file, None, langchain_status, queue=False)
+         load_audio.click(audio_processor, inputs=[audio_file, API_key, wav_model, LLM_option, temperature, max_new_tokens], outputs=[langchain_status], queue=False)
+
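Note: as committed, app.py defines the chat helpers (add_text, bot, infer) and the chat widgets (question, submit_btn, clean_chat_btn) but never connects them, and the Blocks app is never launched. A minimal sketch of the missing wiring, assuming the same gradio 3.x event API used elsewhere in the file (this is not part of the commit; the first three statements would sit inside the `with gr.Blocks(...)` block, the launch call at module level):

    # append the user message, then stream the model answer into the chatbot
    question.submit(add_text, [chatbot, question], [chatbot, question], queue=False).then(
        bot, chatbot, chatbot)
    submit_btn.click(add_text, [chatbot, question], [chatbot, question], queue=False).then(
        bot, chatbot, chatbot)
    # clear the chat history
    clean_chat_btn.click(lambda: None, None, chatbot, queue=False)

# start the Gradio app (assumed entry point for a Space running this file)
demo.launch()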
llm_ops.py ADDED
@@ -0,0 +1,21 @@
+ import os
+
+ def get_openai_chat_model(API_key):
+     try:
+         from langchain.llms import OpenAI
+     except ImportError as err:
+         raise ImportError(f"{err}: unable to load OpenAI. Please install openai and set OPENAI_API_KEY") from err
+     os.environ["OPENAI_API_KEY"] = API_key
+     llm = OpenAI()
+     return llm
+
+ def get_hugging_face_model(model_id, API_key, temperature=0.1, max_tokens=4096):
+     try:
+         from langchain import HuggingFaceHub
+     except ImportError as err:
+         raise ImportError(f"{err}: unable to load HuggingFaceHub. Please install langchain/huggingface_hub and provide a Hugging Face API token") from err
+     chat_llm = HuggingFaceHub(huggingfacehub_api_token=API_key,
+                               repo_id=model_id,
+                               model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens})
+     return chat_llm
+
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ openai
+ tiktoken
+ chromadb
+ langchain
+ unstructured
+ unstructured[local-inference]
+ transformers
+ torch
+ faiss-cpu
+ sentence-transformers
+ youtube-transcript-api
+ openai-whisper
whisper_app.py ADDED
@@ -0,0 +1,69 @@
+ import os
+ import logging
+ import torch as th
+ import whisper
+ from whisper.audio import SAMPLE_RATE
+ from tenacity import retry, wait_random
+ import openai
+ import requests
+ import time
+ # os.environ['OPENAI_API_KEY'] = "sk-<API KEY>"
+
+ class WHISPERModel:
+     def __init__(self, model_name='small', device='cuda', openai_flag=False):
+         self.logger = logging.getLogger(__name__)
+         self.device = device
+         self.openai_flag = openai_flag
+         self.model = whisper.load_model(model_name, device=self.device)
+
+     def get_info(self, audio_data, conv_duration=30):
+         # detect the language from the first `conv_duration` seconds of audio
+         clip_audio = whisper.pad_or_trim(audio_data, length=SAMPLE_RATE * conv_duration)
+         result = self.model.transcribe(clip_audio)
+         return result['language']
+
+     def speech_to_text(self, audio_path):
+         self.logger.info("Reading url {}".format(audio_path))
+         text_data = dict()
+         audio_duration = 0
+         conv_language = ""
+         r = requests.get(audio_path)
+         if r.status_code == 200:
+             try:
+                 audio = whisper.load_audio(audio_path)
+                 conv_language = self.get_info(audio)
+                 if conv_language != 'en':
+                     res = self.model.transcribe(audio, task='translate')
+                     if self.openai_flag:
+                         res['text'] = self.translate_text(res['text'], original_lang=conv_language, convert_to='English')
+                 else:
+                     res = self.model.transcribe(audio)
+                 audio_duration = audio.shape[0] / SAMPLE_RATE
+                 text_data['text'] = res['text']
+                 text_data['duration'] = audio_duration
+                 text_data['language'] = conv_language
+             except IOError as err:
+                 raise IOError(f"Issue in loading audio {audio_path}") from err
+         else:
+             raise ValueError("Unable to reach URL {}".format(audio_path))
+         return text_data
+
+
+     @retry(wait=wait_random(min=5, max=10))
+     def translate_text(self, text, original_lang='ar', convert_to='English'):
+         prompt = f'Translate the following {original_lang} text to {convert_to}:\n\n{original_lang}: ' + text + f'\n{convert_to}:'
+         # Generate response using the OpenAI completion endpoint
+         response = openai.Completion.create(
+             engine='text-davinci-003',
+             prompt=prompt,
+             max_tokens=100,
+             n=1,
+             stop=None,
+             temperature=0.7
+         )
+         # Extract the translated English text from the response
+         translation = response.choices[0].text.strip()
+         return translation
+
+ if __name__ == '__main__':
+     url = "https://prypto-api.aswat.co/surveillance/recordings/5f53c28b-3504-4b8b-9db5-0c8b69a96233.mp3"
+     audio2text = WHISPERModel()
+     text = audio2text.speech_to_text(url)