# File size: 3,646 Bytes
# 0a19530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr

from langchain import LLMChain
from langchain import PromptTemplate
from langchain.llms import Cohere

# from create_chain import chain as llm_chain
from create_chain import create_chain_from_template
from prompt import wikipedia_template, general_internet_template
from langchain.retrievers import CohereRagRetriever
from langchain.chat_models import ChatCohere

import os
from dotenv import load_dotenv



# Pull environment variables in from a local .env file.
# https://pypi.org/project/python-dotenv/
load_dotenv()

# Cohere credentials; expected to be defined in the environment / .env.
COHERE_API_KEY = os.getenv("COHERE_API_KEY")


# Sample questions surfaced under the chat input box.
examples = [
    [question]
    for question in (
        "What is Cellular Automata and who created it?",
        "What is Cohere",
        "who is Katherine Johnson",
    )
]

def create_UI(llm_chain):
    """Build the Gradio chat interface around *llm_chain*.

    Args:
        llm_chain: a runnable chain exposing ``.invoke(query) -> str``; used
            to answer chat messages.  The radio control can swap it at
            runtime for a chain built from a different prompt template.

    Returns:
        gr.Blocks: the assembled (not yet launched) demo.
    """
    with gr.Blocks() as demo:
        radio = gr.Radio(
            ["wikipedia only", "any website"],
            label="What kind of essay would you like to write?",
            value="wikipedia only",
        )

        chatbot = gr.Chatbot()
        msg = gr.Textbox(info="Enter your question here, press enter to submit query")
        clear = gr.Button("Clear")

        gr.Examples(examples=examples, label="Examples", inputs=msg)

        def user(user_message, history):
            # Append the user's turn (bot reply pending) and clear the textbox.
            return "", history + [[user_message, None]]

        def bot(history):
            # Answer the most recent user message with the currently active chain.
            print("Question: ", history[-1][0])
            bot_message = llm_chain.invoke(history[-1][0])
            print("Response: ", bot_message)
            history[-1][1] = bot_message
            return history

        def change_textbox(choice):
            # Swap the active chain to match the selected retrieval mode.
            # BUGFIX: the original assigned a *local* llm_chain and returned
            # it into the Textbox output, so bot() never saw the new chain;
            # rebind the enclosing variable instead and return a status string.
            nonlocal llm_chain
            if choice == "wikipedia only":
                template = wikipedia_template
            elif choice == "any website":
                template = general_internet_template
            else:
                # Defensive: unknown choice leaves the chain untouched.
                return gr.Textbox(visible=False)
            # NOTE(review): `rag` and `llm_model` are module-level globals
            # created under __main__ — consider passing them in explicitly.
            llm_chain = create_chain_from_template(template, rag, llm_model)
            return f"Mode: {choice}"

        text = gr.Textbox(lines=2, interactive=True, show_copy_button=True)
        radio.change(fn=change_textbox, inputs=radio, outputs=[text])
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        clear.click(lambda: None, None, chatbot, queue=False)
    return demo


if __name__ == "__main__":
    # Chat model and RAG retriever backed by Cohere; requires COHERE_API_KEY
    # to be set in the environment (loaded above via dotenv).
    # (Removed an unused `prompt = PromptTemplate(...)` — the template string
    # itself is what create_chain_from_template consumes.)
    llm_model = ChatCohere(
        cohere_api_key=COHERE_API_KEY,
    )

    rag = CohereRagRetriever(llm=llm_model)

    # Default chain answers from Wikipedia; the UI can rebuild it with a
    # different template when the radio selection changes.
    llm_chain = create_chain_from_template(
        wikipedia_template,
        rag,
        llm_model,
    )

    demo = create_UI(llm_chain)
    demo.queue()
    # share=True exposes a temporary public Gradio link in addition to localhost.
    demo.launch(share=True)