# Cohere chat model grounded through a RAG connector (internet search).
#
# References:
# https://python.langchain.com/docs/modules/data_connection/retrievers/
# https://python.langchain.com/docs/integrations/llms/cohere

import os

from dotenv import load_dotenv

from langchain.chat_models import ChatCohere
from langchain.prompts import PromptTemplate
from langchain.retrievers import CohereRagRetriever
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

# Local module holding the prompt template strings.
from prompt import wikipedia_template, general_internet_template

load_dotenv()  # load environment variables from a local .env file
# https://pypi.org/project/python-dotenv/

COHERE_API_KEY = os.getenv("COHERE_API_KEY")
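
# Fail fast with a clear message if the key is missing, rather than letting
# the Cohere client raise a vaguer authentication error later.
if not COHERE_API_KEY:
    raise RuntimeError("COHERE_API_KEY is not set; add it to your .env file.")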

def format_docs(docs):
    """Join retrieved Documents into one newline-separated context string."""
    return "\n\n".join(d.page_content for d in docs)


def create_chain_from_template(template, retriever, model):
    """Build an LCEL chain: retrieve context, fill the prompt, call the
    model, and parse the reply to a plain string.
    """
    # from_template infers the input variables ({context} and {query})
    # from the template string itself, so the declared variables cannot
    # drift out of sync with the dict assembled below.
    prompt = PromptTemplate.from_template(template)
    chain = (
        {"context": retriever | format_docs, "query": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )
    return chain
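
# For reference, the LCEL pipe built above is roughly equivalent to this
# manual flow (a sketch only; assumes a LangChain version where retrievers
# expose get_relevant_documents):
#
#     docs = retriever.get_relevant_documents(query)
#     filled = prompt.format(context=format_docs(docs), query=query)
#     answer = StrOutputParser().invoke(model.invoke(filled))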


if __name__ == "__main__":
    llm_model = ChatCohere(cohere_api_key=COHERE_API_KEY)

    # Retriever that grounds answers in documents fetched through Cohere's
    # RAG connectors.
    rag = CohereRagRetriever(llm=llm_model)

    llm_chain = create_chain_from_template(
        wikipedia_template,
        rag,
        llm_model,
    )
    
    sample_query = "What are cellular automata, and who created them?"
    sample_output = llm_chain.invoke(sample_query)
    print(sample_output)
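
    # A second chain can reuse the same helper with the other imported
    # template (a sketch; assumes general_internet_template, like
    # wikipedia_template, uses the same {context} and {query} placeholders).
    internet_chain = create_chain_from_template(
        general_internet_template,
        rag,
        llm_model,
    )
    print(internet_chain.invoke("Who coined the term 'artificial intelligence'?"))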