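"""Streamlit chatbot that answers questions about a YouTube video transcript.

The transcript is fetched with YoutubeLoader, split into chunks, embedded
into a FAISS index, and queried through an LLMChain backed by
Falcon-7B-Instruct on the Hugging Face Hub.
"""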
from langchain.document_loaders import YoutubeLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.chains import LLMChain
from dotenv import find_dotenv, load_dotenv
from prompts import CHAT_PROMPT
from youtube_transcript_api import NoTranscriptFound
import streamlit as st
import os


class YouTubeChatbot:

    def __init__(self):
        load_dotenv(find_dotenv())

        # st.secrets raises if the key is absent, so test membership instead
        # of comparing against None; setdefault keeps any token already
        # loaded from .env above.
        if "hugging_face_api_key" in st.secrets:
            os.environ.setdefault("HUGGINGFACEHUB_API_TOKEN",
                                  st.secrets.hugging_face_api_key)

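        # Embedding model used to vectorize transcript chunks for FAISS
        # (LangChain's default sentence-transformers model unless configured
        # otherwise).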
        try:
            self.embeddings = HuggingFaceEmbeddings()
        except Exception as e:
            st.error("Failed to load the Hugging Face Embeddings model: " +
                     str(e))
            self.embeddings = None

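        # Falcon-7B-Instruct served through the Hugging Face Hub inference
        # API; the low temperature keeps answers close to the transcript.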
        try:
            repo_id = "tiiuae/falcon-7b-instruct"
            self.falcon_llm = HuggingFaceHub(
                repo_id=repo_id, model_kwargs={"temperature": 0.1, "max_new_tokens": 500}
            )

        except Exception as e:
            st.error("Failed to load the Falcon LLM model: " + str(e))
            self.falcon_llm = None

    # st.cache_resource rather than st.cache_data: the FAISS store wraps a
    # SWIG object that cannot be pickled for the data cache. The leading
    # underscore on _self keeps the instance out of the cache key, so
    # video_url alone identifies a cached database.
    @st.cache_resource
    def create_db_from_youtube_video_url(_self, video_url):
        st.info("Creating FAISS database from YouTube video.")
        loader = YoutubeLoader.from_youtube_url(video_url)
        try:
            transcript = loader.load()
        except NoTranscriptFound:
            st.error("No transcript found for the video.")
            return None

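        # Split into ~1000-character chunks with 100 characters of overlap so
        # retrieved passages keep context across chunk boundaries.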
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                       chunk_overlap=100)
        docs = text_splitter.split_documents(transcript)
        st.info("Number of documents: " + str(len(docs)))

        try:
            db = FAISS.from_documents(docs, _self.embeddings)
            st.text("Created FAISS database from documents.")
            return db
        except Exception as e:
            st.error("Failed to create FAISS database from documents: " +
                     str(e))
            return None

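    # _db carries a leading underscore for the same reason as _self: the
    # FAISS store is unhashable, so only query and k form the cache key.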
    @st.cache_data
    def get_response_from_query(_self, _db, query, k=4):
        if _db is None:
            st.error(
                "Database is not initialized. Please check the error messages."
            )
            return None

        if _self.falcon_llm is None:
            st.error(
                "Falcon LLM model is not loaded. Please check the error messages."
            )
            return None

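        # Retrieve the k most relevant transcript chunks and join them into a
        # single context string for the prompt.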
        docs = _db.similarity_search(query, k=k)
        docs_page_content = " ".join([d.page_content for d in docs])

        try:
            chain = LLMChain(llm=_self.falcon_llm, prompt=CHAT_PROMPT)
            response = chain.run(
                question=query,
                docs=docs_page_content
            )
            # Join on spaces rather than deleting newlines outright, so words
            # on adjacent lines are not fused together.
            response = response.replace("\n", " ").strip()
            return response
        except Exception as e:
            st.error("Failed to generate a response: " + str(e))
            return None
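

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module: one way the class
# above could be driven from a Streamlit page. The widget labels are
# illustrative assumptions, not names taken from elsewhere in this repo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    chatbot = YouTubeChatbot()

    video_url = st.text_input("YouTube video URL")
    question = st.text_input("Question about the video")

    if video_url and question:
        db = chatbot.create_db_from_youtube_video_url(video_url)
        answer = chatbot.get_response_from_query(db, question)
        if answer is not None:
            st.write(answer)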