# RAG-Chatbot-Interface / interaction.py
# Logeswaransr's picture
# Update interaction.py
# 3995ebc verified
# raw
# history blame
# 2.63 kB
from gtts import gTTS
import base64
import os
from haystack import Document
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.components.builders import PromptBuilder
from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator
from haystack.pipeline import Pipeline
def _load_document(path, name):
    """Read ``path + '/' + name`` as text and wrap it in a ``Document``."""
    with open(path + '/' + name, 'r') as handle:
        text = handle.read()
    return Document(content=text, meta={'name': name})


def init_doc_store(path, files):
    """Build an in-memory document store from text files in a directory.

    Args:
        path: Directory holding the files; joined to each file name with
            a ``'/'`` separator.
        files: Iterable of file names to load from ``path``.

    Returns:
        A new ``InMemoryDocumentStore`` containing one ``Document`` per
        file, with the file's text as ``content`` and the file name
        stored under the ``'name'`` key of ``meta``.
    """
    # Files are opened in text mode with the platform default encoding.
    loaded = [_load_document(path, name) for name in files]

    store = InMemoryDocumentStore()
    store.write_documents(loaded)
    return store
def define_components(document_store):
    """Create the retriever, prompt builder and generator for the RAG pipeline.

    Args:
        document_store: Document store handed to the BM25 retriever.

    Returns:
        A ``(retriever, prompt_builder, generator)`` tuple:

        * ``retriever`` -- ``InMemoryBM25Retriever`` over ``document_store``
          configured with ``top_k=3``.
        * ``prompt_builder`` -- ``PromptBuilder`` whose template renders the
          ``documents`` and ``question`` variables.
        * ``generator`` -- ``HuggingFaceLocalGenerator`` for the ``gpt2``
          model on the ``text-generation`` task; ``warm_up()`` has already
          been called on it.
    """
    retriever = InMemoryBM25Retriever(document_store, top_k=3)

    # Template variables: `documents` (each item's `.content` is rendered)
    # and `question`.
    template = """
Given the following information, answer the question.
Context:
{% for document in documents %}
{{ document.content }}
{% endfor %}
Question: {{question}}
Answer:
"""
    prompt_builder = PromptBuilder(template=template)

    generator = HuggingFaceLocalGenerator(
        model="gpt2",
        task="text-generation",
        # device='cuda',
        generation_kwargs={
            "max_new_tokens": 100,
            "temperature": 0.9,
        },
    )
    generator.warm_up()

    # Fix: the original returned the misspelled name `retreiver`, which is
    # not defined in this function and raised NameError on every call.
    return retriever, prompt_builder, generator
def define_pipeline(retreiver, prompt_builder, generator):
    """Assemble the retriever -> prompt builder -> generator pipeline.

    Args:
        retreiver: Retriever component, registered as ``"retriever"``.
            The parameter name keeps its original (misspelled) form so
            existing keyword callers continue to work.
        prompt_builder: Prompt builder component, registered as
            ``"prompt_builder"``.
        generator: Generator component, registered as ``"llm"``.

    Returns:
        A ``Pipeline`` in which the ``"retriever"`` output is connected to
        ``prompt_builder.documents`` and ``"prompt_builder"`` is connected
        to ``"llm"``.
    """
    basic_rag_pipeline = Pipeline()

    # Fix: the original passed the undefined name `retriever` here instead
    # of the `retreiver` parameter, raising NameError (or silently picking
    # up an unrelated global of that name).
    basic_rag_pipeline.add_component("retriever", retreiver)
    basic_rag_pipeline.add_component("prompt_builder", prompt_builder)
    basic_rag_pipeline.add_component("llm", generator)

    basic_rag_pipeline.connect("retriever", "prompt_builder.documents")
    basic_rag_pipeline.connect("prompt_builder", "llm")
    return basic_rag_pipeline
def generate_response(question, pipeline):
    """Run *pipeline* on *question* and return the first generated reply.

    Args:
        question: Text passed both as the ``"retriever"`` component's
            ``query`` and as the ``"prompt_builder"`` component's
            ``question``.
        pipeline: Object whose ``run`` method accepts the per-component
            input mapping and returns a mapping containing
            ``['llm']['replies']``.

    Returns:
        The first element of the ``replies`` produced by the ``"llm"``
        component.
    """
    component_inputs = {
        'retriever': {"query": question},
        'prompt_builder': {'question': question},
    }
    result = pipeline.run(component_inputs)
    replies = result['llm']['replies']
    return replies[0]
def audio_response(response):
    """Convert *response* text to speech and return it as an HTML audio tag.

    The speech is synthesised with ``gTTS`` and saved to
    ``response_audio.mp3`` in the current working directory; that file is
    then read back and embedded as a base64 ``data:`` URI.

    Args:
        response: Text to synthesise.

    Returns:
        An ``<audio autoplay="true" ...>`` tag string whose ``src`` is the
        base64-encoded MP3 data.
    """
    mp3_path = "response_audio.mp3"
    gTTS(response).save(mp3_path)

    with open(mp3_path, 'rb') as mp3_file:
        encoded = base64.b64encode(mp3_file.read()).decode('utf-8')

    return f'<audio autoplay="true" src="data:audio/mp3;base64,{encoded}">'