File size: 5,363 Bytes
4531c67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95f899c
4531c67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import re

from agentops.langchain_callback_handler import LangchainCallbackHandler as AgentOpsLangchainCallbackHandler
from dotenv import load_dotenv
from langchain import hub
from langchain.agents import initialize_agent, AgentType, load_tools
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.evaluation import load_evaluator
from langchain.prompts import MessagesPlaceholder
from langchain_community.chat_models import ChatOllama
from langchain_community.llms import Ollama

from tools.cve_avd_tool import CVESearchTool
from tools.misp_tool import MispTool
from tools.coder_tool import CoderTool
from tools.mitre_tool import MitreTool

# Load settings such as OLLAMA_HOST and AGENTOPS_API_KEY from a .env file;
# override=True lets the file take precedence over values already in the
# process environment.
load_dotenv(override=True)

# Every Ollama client below talks to the same server.
_ollama_host = os.getenv('OLLAMA_HOST')

# ---------------------------------------------------------------------------
# Language models
# ---------------------------------------------------------------------------
# Main model: drives the agent and the parsing-error summaries in _handle_error.
llm = Ollama(
    model="codestral",
    base_url=_ollama_host,
    temperature=0.3,
    num_predict=8192,
    num_ctx=32768,
)

# Alternative local models; none of them is passed to the agent.
wrn = Ollama(model="openfc", base_url=_ollama_host)
llama3 = Ollama(model="llama3", base_url=_ollama_host, temperature=0.3)
command_r = Ollama(
    model="command-r",
    base_url=_ollama_host,
    temperature=0.1,
    num_ctx=8192,
)
hermes_llama3 = Ollama(
    model="adrienbrault/nous-hermes2pro-llama3-8b:q4_K_M",
    base_url=_ollama_host,
    temperature=0.3,
    num_ctx=32768,
)
yarn_mistral_128k = Ollama(
    model="yarn-mistral-modified",
    base_url=_ollama_host,
    temperature=0.1,
    num_ctx=65536,
    system="",  # explicitly empty system prompt
)

# Chat model used as the judge LLM of the trajectory evaluator.
# num_predict=-1 lifts Ollama's cap on generated tokens.
chat_llm = ChatOllama(model="openhermes", base_url=_ollama_host, num_predict=-1)

# ---------------------------------------------------------------------------
# Tools exposed to the agent
# ---------------------------------------------------------------------------
# Each tool is read off its own wrapper instance; the individual names stay
# at module level.
# CVE lookups
cve_search_tool = CVESearchTool().cvesearch
fetch_cve_tool = CVESearchTool().fetchcve
# MISP searches
misp_search_tool = MispTool().search
misp_search_by_date_tool = MispTool().search_by_date
misp_search_by_event_id_tool = MispTool().search_by_event_id
# Code generation
coder_tool = CoderTool().code_generation_tool
# MITRE lookups (techniques, malware, tactics)
get_technique_by_id = MitreTool().get_technique_by_id
get_technique_by_name = MitreTool().get_technique_by_name
get_malware_by_name = MitreTool().get_malware_by_name
get_tactic_by_keyword = MitreTool().get_tactic_by_keyword

tools = [
    cve_search_tool,
    fetch_cve_tool,
    misp_search_tool,
    misp_search_by_date_tool,
    misp_search_by_event_id_tool,
    coder_tool,
    get_technique_by_id,
    get_technique_by_name,
    get_malware_by_name,
    get_tactic_by_keyword,
]

# Conversation memory: keeps the last k=4 exchanges under 'chat_history',
# returned as message objects rather than one concatenated string.
memory = ConversationBufferWindowMemory(
    memory_key='chat_history',
    k=4,
    return_messages=True,
)

# AgentOps callback handler (its use in the agent's callbacks is commented out).
agentops_handler = AgentOpsLangchainCallbackHandler(
    api_key=os.getenv("AGENTOPS_API_KEY"),
    tags=['Langchain Example'],
)

#Error handling
def _handle_error(error) -> str:

    pattern = r'```(?!json)(.*?)```'
    match = re.search(pattern, str(error), re.DOTALL)
    if match: 
        return "The answer contained a code blob which caused the parsing to fail, i recovered the code blob. Just use it to answer the user question: " + match.group(1)
    else: 
        return llm.invoke(f"""Try to summarize and explain the following error into 1 short and consice sentence and give a small indication to correct the error: {error} """)


# Prompt pulled from the LangChain Hub.
# NOTE(review): initialize_agent builds its own structured-chat prompt. It
# forwards unknown keyword arguments such as `prompt=` to AgentExecutor, which
# appears to ignore them, so this hub prompt is most likely unused. Confirm
# before relying on it.
prompt = hub.pull("hwchase17/react-chat-json")

# The default structured-chat prompt only declares `input` and
# `agent_scratchpad`. Without a slot for the memory's `chat_history`, the
# window memory is loaded on every call and then dropped before the prompt is
# formatted. This placeholder inserts the remembered messages between the
# system message and the human turn.
_chat_history_placeholder = MessagesPlaceholder(variable_name="chat_history")

# create our agent
conversational_agent = initialize_agent(
    # agent="chat-conversational-react-description",
    agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
    tools=tools,
    prompt=prompt,
    llm=llm,
    verbose=True,
    max_iterations=5,
    memory=memory,
    early_stopping_method='generate',
    # callbacks=[agentops_handler],
    handle_parsing_errors=_handle_error,
    return_intermediate_steps=False,
    max_execution_time=40,
    agent_kwargs={
        "memory_prompts": [_chat_history_placeholder],
        "input_variables": ["input", "agent_scratchpad", "chat_history"],
    },
)

# Judge for agent trajectories (needs intermediate steps, see invoke()).
evaluator = load_evaluator("trajectory", llm=chat_llm)


# conversational_agent.agent.llm_chain.prompt.messages[0].prompt.template = """
# 'Respond to the human as helpfully and accurately as possible. 
# You should use the tools available to you to help answer the question.
# Your final answer should be technical, well explained, and accurate.
# You have access to the following tools:\n\n\n\nUse a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n\nValid "action" values: "Final Answer" or \n\nProvide only ONE action per $JSON_BLOB, as shown:\n\n```\n{{\n  "action": $TOOL_NAME,\n  "action_input": $INPUT\n}}\n```\n\nFollow this format:\n\nQuestion: input question to answer\nThought: consider previous and subsequent steps\nAction:\n```\n$JSON_BLOB\n```\nObservation: action result\n... (repeat Thought/Action/Observation N times)\nThought: I know what to respond\nAction:\n```\n{{\n  "action": "Final Answer",\n  "action_input": "Final response to human"\n}}\n```\n\nBegin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.\nThought:'
# """

# Original system-message template generated by the agent.
template = conversational_agent.agent.llm_chain.prompt.messages[0].prompt.template

# Prepend the analyst persona to the system message. The blank line keeps the
# persona sentence from running straight into the default instructions
# (which start "Respond to the human ...", as in the draft above).
conversational_agent.agent.llm_chain.prompt.messages[0].prompt.template = (
    "You are a cyber security analyst, your role is to respond to the human "
    "queries in a technical way while providing detailed explanations when "
    "providing final answer.\n\n" + template
)

def invoke(input_text):
    """Run one conversational turn through the security-analyst agent.

    Args:
        input_text: The user's message for this turn.

    Returns:
        The agent's final answer (the ``"output"`` entry of the chain result).
    """
    # Chain.__call__ is deprecated; .invoke() is the Runnable entry point for
    # the same AgentExecutor run, including memory load/save.
    results = conversational_agent.invoke({"input": input_text})
    # NOTE: scoring with `evaluator.evaluate_agent_trajectory(...)` needs
    # results["intermediate_steps"]. That key is only present when the agent is
    # built with return_intermediate_steps=True.
    return results["output"]