# ai / app.py — uploaded by netman19731
# Last change: "Update app.py" (commit 540d82e, verified)
from langchain_openai.chat_models import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.tools.render import format_tool_to_openai_function
from langgraph.prebuilt import ToolExecutor,ToolInvocation
from typing import TypedDict, Annotated, Sequence
import operator
from langchain_core.messages import BaseMessage,FunctionMessage,HumanMessage,AIMessage
from langchain_community.tools import ShellTool,tool
import json
import os
import gradio as gr
# Runtime configuration for LangSmith tracing, the OpenAI-compatible client and
# Tavily search.
# setdefault() lets values already configured in the deployment environment take
# precedence; the literals below are only fallbacks. (Previously they were
# assigned unconditionally, silently overwriting any configured value.)
# NOTE(review): these keys are committed to source control and should be treated
# as leaked — rotate them and drop the fallbacks once the environment supplies them.
os.environ.setdefault("LANGCHAIN_TRACING_V2", "True")
os.environ.setdefault("LANGCHAIN_API_KEY", "ls__54e16f70b2b0455aad0f2cbf47777d30")
os.environ.setdefault("OPENAI_API_KEY", "20a79668d6113e99b35fcd541c65bfeaec497b8262c111bd328ef5f1ad8c6335")
os.environ.setdefault("LANGCHAIN_ENDPOINT", "https://api.smith.langchain.com")
os.environ.setdefault("LANGCHAIN_PROJECT", "default")
os.environ.setdefault("TAVILY_API_KEY", "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")
class AgentState(TypedDict):
    """State passed between the nodes of the agent graph.

    ``messages`` is annotated with ``operator.add``: each node returns a list
    of new messages, and that list is added onto the existing conversation
    rather than replacing it.
    """

    messages: Annotated[Sequence[BaseMessage], operator.add]
import time
import jwt
def generate_token(apikey: str, exp_seconds: int):
    """Sign a JWT used as the bearer token for the GLM API.

    Args:
        apikey: Key of the form ``"<id>.<secret>"``. The id goes into the
            payload and the secret signs the token.
        exp_seconds: How long the token stays valid, in seconds.

    Returns:
        The HS256-signed token produced by ``jwt.encode``.

    Raises:
        ValueError: If ``apikey`` is not a string that splits into exactly two
            "."-separated parts. ``ValueError`` subclasses ``Exception``, so
            existing ``except Exception`` handlers still catch it.
    """
    try:
        key_id, secret = apikey.split(".")
    except (AttributeError, ValueError) as err:
        raise ValueError("invalid apikey") from err

    # Read the clock once so "exp" and "timestamp" are consistent. Both are in
    # milliseconds; exp_seconds is converted explicitly.
    now_ms = int(round(time.time() * 1000))
    payload = {
        "api_key": key_id,
        "exp": now_ms + exp_seconds * 1000,
        "timestamp": now_ms,
    }
    return jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )
from langchain_openai import ChatOpenAI
# from jwt import generate_token
def get_glm(temprature, exp_seconds=10000, api_key=None):
    """Build a ChatOpenAI client pointed at the GLM-4 endpoint on open.bigmodel.cn.

    Args:
        temprature: Sampling temperature for the model. The misspelled name is
            kept because callers may pass it by keyword.
        exp_seconds: Lifetime, in seconds, of the JWT signed here and used as the
            API key. Defaults to 10000, the value previously hard-coded.
        api_key: Key in ``"<id>.<secret>"`` form. When omitted, the
            ``ZHIPUAI_API_KEY`` environment variable is used, falling back to the
            key previously hard-coded here.

    Returns:
        A non-streaming ``ChatOpenAI`` instance for model ``glm-4``.

    Note:
        The token is signed once, when this function runs. A client kept longer
        than ``exp_seconds`` will presumably be rejected once the token expires.
    """
    if api_key is None:
        # NOTE(review): this fallback key is committed to source control; move it
        # into the deployment's secret store and rotate it.
        api_key = os.environ.get(
            "ZHIPUAI_API_KEY", "bdc66124ffee87e2cae1aff403831c29.IfV2i1fN822Bwj7X"
        )
    return ChatOpenAI(
        model_name="glm-4",
        openai_api_base="https://open.bigmodel.cn/api/paas/v4",
        openai_api_key=generate_token(api_key, exp_seconds),
        streaming=False,
        temperature=temprature,
    )
from langchain_core.prompts import ChatPromptTemplate
# System prompt. In English: "You are the manager of the Xiyou restaurant, named
# Tang Seng, and you serve Chinese food. You have three staff: chef Bajie, waiter
# Sha Seng and cashier Wukong. Give them instructions according to the customer's
# needs, following this procedure:
#   1. When the customer wants to order, check the dish is Chinese food; if not,
#      politely decline, otherwise continue.
#   2. Tell chef Bajie to cook the dish and ask the customer to wait.
#   3. Check whether the dish is done; keep waiting until it is.
#   4. Tell Sha Seng to bring the dish to the customer; invite them to eat.
#   5. When the customer is finished or wants to pay, tell Wukong to check out.
#   6. After checkout, thank the customer and end the service."
_SYSTEM_PROMPT = '''你是西游餐厅经理,你叫唐僧,能为顾客提供中餐服务;
你有三个员工,分别是:厨师八戒,侍者沙僧,收银悟空;
你需要根据顾客的需求,按照流程向员工下达指令,流程如下:
1.当顾客表达要点菜的意愿后,先判断是否属于中餐,如果不是,委婉的拒绝服务,如果是,执行下一步骤;
2.向厨师八戒下达指令,让八戒做菜,请顾客稍等;
3.判断菜是否做完,如果还没做完,继续等待;如果做完了,执行下一步骤;
4.向沙僧下达指令,让沙僧把菜端给顾客;请顾客品尝;
5.当顾客表达吃完了或者想结账的时候,向悟空下达指令,让悟空结账;
6.当结账完成后,向顾客表达感谢,并结束服务。
'''
# Canned assistant reply that follows the system prompt ("OK, I will strictly
# follow the procedure and provide service.").
_ASSISTANT_ACK = "好的,我将严格遵守流程,并提供服务。"

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", _SYSTEM_PROMPT),
        ("assistant", _ASSISTANT_ACK),
    ]
)
# Chef tool ("chushi"). The docstring below is left exactly as-is because
# LangChain's @tool uses it as the tool description sent to the model.
# NOTE(review): return_direct=True appears to have no effect in this app: the
# graph always routes the action node back to the agent node.
@tool(return_direct=True)
def chushi(query: str) -> str:
    '''你是餐厅厨师八戒,能根据经理的指令,做出一道菜'''
    # `query` (the manager's instruction) is part of the tool schema but is not
    # used: the reply is a fixed script. (A previous unused local here was
    # accidentally a tuple and shadowed the builtin `input`.)
    return "厨师八戒:接到指令,开始做菜!\n...\n菜已做好!"
# Waiter tool ("shizhe"). The docstring below is left exactly as-is because
# LangChain's @tool uses it as the tool description sent to the model.
@tool
def shizhe(query: str) -> str:
    '''你是餐厅侍者沙僧,能根据经理的指令,把菜端到顾客面前'''
    # `query` is part of the tool schema but unused; the reply is a fixed script.
    # The stray comma that used to start the last line (",菜已送到") is removed.
    return "侍者沙僧:收到指令,开始送菜!\n...\n菜已送到"
# Cashier tool ("shouyin"). The docstring below is left exactly as-is because
# LangChain's @tool uses it as the tool description sent to the model.
@tool
def shouyin(query: str) -> str:
    '''你是餐厅收银悟空,能根据经理的指令,为顾客结账'''
    # `query` is part of the tool schema but unused; the reply is a fixed script.
    return "结账完成,欢迎下次光临"
# The staff tools the manager (the LLM) can call.
tools = [chushi, shizhe, shouyin]
from langchain_community.tools.convert_to_openai import format_tool_to_openai_tool

# OpenAI-style tool schemas bound to every request made through `model`.
# (`t` avoids shadowing the imported `tool` decorator name.)
_tool_schemas = [format_tool_to_openai_tool(t) for t in tools]
model = get_glm(0.01).bind(tools=_tool_schemas)

# Runs the ToolInvocation that call_tool builds against the matching tool.
tool_executor = ToolExecutor(tools)
def should_continue(state):
    """Pick the next hop after the ``agent`` node.

    Returns:
        ``"continue"`` (run the ``action`` node) when the latest message carries
        at least one entry in ``additional_kwargs["tool_calls"]``; otherwise
        ``"end"``.
    """
    last_message = state["messages"][-1]
    # Test truthiness, not key presence: an empty or None "tool_calls" entry
    # means there is nothing for the action node to execute.
    if last_message.additional_kwargs.get("tool_calls"):
        return "continue"
    return "end"
# Define the function that calls the model
def call_model(state):
    """Graph node: send the conversation to the model and append its reply.

    Returns:
        ``{"messages": [response]}``. The one-element list is added onto the
        existing message list by the state reducer.
    """
    # get_glm() signs a JWT that, by default, expires 10000 s after it is
    # created. The module-level `model` is built once at import, so after a few
    # hours of uptime every request would fail authentication. Rebuild the client
    # per call (cheap next to the LLM round-trip), reusing the module-level
    # binding's temperature and bound kwargs (the tool schemas) so both stay in sync.
    fresh_model = get_glm(model.bound.temperature).bind(**model.kwargs)
    response = fresh_model.invoke(state["messages"])
    return {"messages": [response]}
# Define the function to execute tools
def call_tool(state):
    """Graph node: run every tool call requested by the latest model message.

    Each call's ``function.arguments`` JSON string is decoded and executed via
    the module-level ``tool_executor``. Each result is wrapped in a
    ``HumanMessage`` and appended to the conversation.

    Returns:
        ``{"messages": [...]}`` with one message per executed tool call.
    """
    last_message = state["messages"][-1]
    results = []
    # Execute all requested calls; previously only the first one ran and the
    # rest were silently dropped.
    for call in last_message.additional_kwargs["tool_calls"]:
        action = ToolInvocation(
            tool=call["function"]["name"],
            tool_input=json.loads(call["function"]["arguments"]),
        )
        response = tool_executor.invoke(action)
        print(response)
        # NOTE(review): a HumanMessage is used instead of a FunctionMessage
        # (whose use was left commented out), presumably because the GLM
        # endpoint did not accept it — confirm.
        # str() keeps the content textual even if a tool returns a non-string.
        results.append(HumanMessage(content=str(response)))
    return {"messages": results}
from langgraph.graph import StateGraph, END
# Assemble the agent loop. "agent" asks the model what to do next; if the model
# requested tools, "action" runs them and control returns to "agent".
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("action", call_tool)
workflow.set_entry_point("agent")

# After each model turn, should_continue() names the next hop:
# "continue" -> run the tools, "end" -> stop (END is LangGraph's terminal node).
_routes = {"continue": "action", "end": END}
workflow.add_conditional_edges("agent", should_continue, _routes)

# Tool results always go back to the model for the next decision.
workflow.add_edge("action", "agent")

# compile() produces the runnable that predict() invokes.
app = workflow.compile()
async def predict(message, history):
    """Gradio ChatInterface handler: run one customer turn through the agent graph.

    Args:
        message: The customer's new chat message.
        history: Previous turns as ``(user_text, bot_text)`` pairs. This assumes
            ChatInterface's tuple history format; confirm against the installed
            Gradio version.

    Returns:
        The text of every message the graph appended during this turn (model
        replies and tool outputs), each followed by a newline. If the graph
        returns nothing, an apology message is returned instead.
    """
    # format_messages() yields real System/AI messages. format() would flatten
    # the whole template into one string that reaches the model as plain user text.
    conversation = list(prompt.format_messages())
    for human, ai in history:
        conversation.append(HumanMessage(content=(human + "\n")))
        # A bot turn can be None (e.g. after a failed turn); treat it as empty.
        conversation.append(AIMessage(content=((ai or "") + "\n")))
    conversation.append(HumanMessage(content=(message + "\n")))
    n_inputs = len(conversation)

    # ainvoke keeps the blocking model/tool calls off Gradio's event loop.
    res = await app.ainvoke({"messages": conversation})
    if not res:
        # Show the apology to the user instead of only printing it on the server.
        return "不好意思,出了一个小问题,请联系我的微信:13603634456"

    # The state reducer appends to the input list, so everything after the
    # inputs is new this turn. The old fixed [2:] slice was only correct with
    # empty history and otherwise echoed earlier turns back to the user.
    new_messages = res["messages"][n_inputs:]
    print(new_messages)
    return "".join(m.content + "\n" for m in new_messages)
# Gradio chat UI: every customer message is handled by predict().
demo = gr.ChatInterface(
    fn=predict,
    title="西游餐厅",
    description="西游餐厅开张了,我是经理唐僧,欢迎光临,您有什么需求,可以直接问我哦!",
)
demo.launch()