File size: 3,233 Bytes
30596e4
 
 
 
3f39329
30596e4
 
 
 
c6cb298
 
30596e4
b5b988a
 
3f39329
 
 
 
 
 
 
 
 
b5b988a
 
 
 
 
30596e4
 
 
 
519bb78
b5b988a
30596e4
 
519bb78
49dc05d
 
 
 
 
519bb78
30596e4
 
 
 
 
 
 
3f39329
30596e4
3f39329
30596e4
 
 
 
 
 
 
 
 
 
 
b5b988a
 
 
 
c8fe20b
b5b988a
30596e4
519bb78
 
 
c8fe20b
 
 
49dc05d
 
 
c8fe20b
 
 
 
 
c6cb298
c8fe20b
c6cb298
 
 
 
c8fe20b
793cb2b
c6cb298
c8fe20b
 
c6cb298
 
 
fcbd089
793cb2b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# LangChain supports many other chat models. Here, we're using Ollama
from langchain_community.chat_models import ChatOllama
from langchain_core.prompts import ChatPromptTemplate
from langchain import hub
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import (
    ReActJsonSingleInputOutputParser,
)
# Import things that are needed generically
from typing import List, Dict
from langchain.tools.render import render_text_description
import os
import dotenv
from innovation_pathfinder_ai.structured_tools.structured_tools import (
    arxiv_search, get_arxiv_paper, google_search
)

# hacky and should be replaced with a database
from innovation_pathfinder_ai.source_container.container import (
    all_sources
)


dotenv.load_dotenv()

# Ollama server URL. "OLLMA_BASE_URL" is misspelled but kept as the primary
# key (and module-level name) so existing .env files and importers keep
# working; the correctly spelled OLLAMA_BASE_URL is accepted as a fallback.
OLLMA_BASE_URL = os.getenv("OLLMA_BASE_URL") or os.getenv("OLLAMA_BASE_URL")


# supports many more optional parameters. Hover on your `ChatOllama(...)`
# class to view the latest available supported parameters.
# Only pass base_url when one is configured: ChatOllama's base_url is a plain
# str field with its own default, so forwarding None from an unset env var
# would be rejected rather than falling back to the default server.
_ollama_kwargs = {"base_url": OLLMA_BASE_URL} if OLLMA_BASE_URL else {}
llm = ChatOllama(
    model="mistral:instruct",
    **_ollama_kwargs,
)


# Tools the agent may call. Their names and descriptions are rendered into
# the ReAct prompt, and the executor dispatches actions to them. Order is kept.
tools = [arxiv_search, google_search, get_arxiv_paper]


# Fetch the ReAct-JSON prompt from the LangChain hub, then pre-fill its
# `tools` (text descriptions) and `tool_names` (comma-separated) slots.
prompt = hub.pull("hwchase17/react-json").partial(
    tools=render_text_description(tools),
    tool_names=", ".join(tool.name for tool in tools),
)


# Stop generation at "\nObservation" so the model does not write tool
# observations itself; the executor runs the tool and supplies the result.
chat_model_with_stop = llm.bind(stop=["\nObservation"])


def _pick_input(inputs):
    """Forward the user's question from the executor's input mapping."""
    return inputs["input"]


def _render_scratchpad(inputs):
    """Render the executor's intermediate steps as text for the scratchpad."""
    return format_log_to_str(inputs["intermediate_steps"])


# Pipeline: input mapping -> prompt -> stop-bound chat model -> JSON action parser.
agent = (
    {"input": _pick_input, "agent_scratchpad": _render_scratchpad}
    | prompt
    | chat_model_with_stop
    | ReActJsonSingleInputOutputParser()
)

# Executor that drives the agent loop over `tools`; verbose=True logs each step.
# handle_parsing_errors is left at its default, so malformed model output
# raises. Passing handle_parsing_errors=True would instead feed the parse
# error back to the model.
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    
if __name__ == "__main__":
    # Example research query. The download tool is referenced by its
    # registered `.name`, so the instruction always names a tool the agent
    # actually has. Each fragment ends with a space so the concatenated
    # sentences do not run together.
    query = (
        "How to generate videos from images using state of the art machine "
        "learning models; Using the arxiv retriever "
        "add the urls of the papers used in the final answer using the "
        "metadata from the retriever please do not use '`' "
        f"please use the `{get_arxiv_paper.name}` tool to download any "
        "arxiv paper you find. "
        "Please only use the tools provided to you"
    )

    # `result` avoids shadowing the builtin `input`.
    result = agent_executor.invoke({"input": query})
    print(result["output"])