#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Small command-line client for the AgentX ``/api/v1/access`` REST API.

``main()`` looks up an agent by name, seeds a fresh conversation with a
context, asks one streamed question and prints the answer plus the reference
trace returned for it.
"""
import argparse
import codecs
import json
import os
import random
import sys
import time
from typing import Iterator, List, Optional

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, '../../'))

import requests

from project_settings import environment


def get_args():
    """Parse command-line arguments.

    ``--api_key`` defaults to whatever
    ``environment.get("agent_x_api_key", default=None)`` returns.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--api_key",
        default=environment.get("agent_x_api_key", default=None),
        type=str
    )
    args = parser.parse_args()
    return args


class AgentXError(AssertionError):
    """Raised when the AgentX API returns an unexpected response.

    Subclasses ``AssertionError`` because this client previously signalled
    API failures with bare ``AssertionError``; existing
    ``except AssertionError`` handlers keep working.
    """


class AgentX(object):
    """Thin wrapper around the AgentX ``/api/v1/access`` endpoints for one agent."""

    def __init__(self,
                 api_key: str,
                 agent_name: str = "NXLink智能助手",
                 url_host: str = "https://api.agentx.so",
                 timeout: Optional[float] = None,
                 ):
        """Create a client and resolve ``agent_name`` to its id.

        :param api_key: value sent in the ``x-api-key`` header.
        :param agent_name: name of the agent to look up.
        :param url_host: scheme + host of the API.
        :param timeout: per-request timeout in seconds passed to ``requests``;
            ``None`` (the default) keeps requests' no-timeout behaviour.
        :raises AgentXError: if the agent list cannot be fetched or no agent
            with ``agent_name`` exists.
        """
        self.api_key = api_key
        self.agent_name = agent_name
        self.url_host = url_host
        self.timeout = timeout
        # Performs a network request.
        self.agent_id = self.get_agent_id()

    def __str__(self):
        result = "<{}; agent_name: {}; agent_id: {}; api_key: {}>".format(
            self.__class__.__name__, self.agent_name, self.agent_id, self._masked_api_key())
        return result

    def _masked_api_key(self) -> str:
        """Return the API key with all but its last 4 characters hidden.

        Keeps credentials out of stdout/logs when the client is printed.
        """
        key = self.api_key or ""
        if len(key) <= 8:
            return "***"
        return "***" + key[-4:]

    def _headers(self, json_content: bool = False) -> dict:
        """Build the common request headers, optionally declaring a JSON body."""
        headers = {
            "accept": "*/*",
            "x-api-key": self.api_key,
        }
        if json_content:
            headers["Content-type"] = "application/json"
        return headers

    def _request(self, method: str, path: str, *,
                 json_content: bool = False,
                 payload: Optional[dict] = None,
                 stream: bool = False):
        """Send ``method`` to ``url_host + path`` and return the raw response.

        ``payload``, when given, is serialized with ``json.dumps`` and sent as
        the request body.
        """
        url = "{}{}".format(self.url_host, path)
        data = None if payload is None else json.dumps(payload)
        return requests.request(
            method,
            url=url,
            headers=self._headers(json_content),
            data=data,
            stream=stream,
            timeout=self.timeout,
        )

    @staticmethod
    def _check_status(resp) -> None:
        """Raise ``AgentXError`` (with status and body) unless ``resp`` is HTTP 200."""
        if resp.status_code != 200:
            raise AgentXError("AgentX request failed: HTTP {}: {}".format(
                resp.status_code, resp.text))

    def get_agent_id(self):
        """Return the ``_id`` of the agent named ``self.agent_name``.

        If several agents share the name, the last one listed wins.

        :raises AgentXError: on a non-200 response or when no agent matches.
        """
        resp = self._request("GET", "/api/v1/access/agents")
        self._check_status(resp)

        agent_id = None
        for agent in resp.json():
            if agent["name"] == self.agent_name:
                agent_id = agent["_id"]
        if agent_id is None:
            raise AgentXError("agent not found: {!r}".format(self.agent_name))
        return agent_id

    def get_agent_config(self):
        """Return the decoded JSON configuration of this agent."""
        resp = self._request("GET", "/api/v1/access/agents/{}".format(self.agent_id))
        return resp.json()

    def get_conversation_list(self):
        """Return the decoded JSON list of this agent's conversations."""
        resp = self._request("GET", "/api/v1/access/agents/{}/conversations".format(self.agent_id))
        return resp.json()

    def post_message(self, message: str, conversation_id: str, context: int = 0):
        """Send ``message`` to a conversation and return the decoded JSON reply.

        :param context: sent verbatim as the ``context`` field of the body
            (``question_answer`` passes 1 when it seeded a context, else 0).
        :raises AgentXError: on a non-200 response.
        """
        resp = self._request(
            "POST",
            "/api/v1/access/conversations/{}/message".format(conversation_id),
            json_content=True,
            payload={"message": message, "context": context},
        )
        self._check_status(resp)
        return resp.json()

    def post_message_by_sse(self, message: str, conversation_id: str, context: int = 0):
        """Send ``message`` via the streaming endpoint.

        :returns: ``(text_iterator, trace_id)``. ``text_iterator`` yields
            decoded text fragments as they arrive. ``trace_id`` comes from the
            ``x-trace-id`` response header.
        :raises AgentXError: on a non-200 response.
        """
        resp = self._request(
            "POST",
            "/api/v1/access/conversations/{}/messagesse".format(conversation_id),
            json_content=True,
            payload={"message": message, "context": context},
            stream=True,
        )
        if resp.status_code != 200:
            # Check the status before touching x-trace-id, which an error
            # response may not carry.
            detail = "HTTP {} ({}): {}".format(
                resp.status_code, resp.headers.get("Content-Type"), resp.text)
            resp.close()
            raise AgentXError("AgentX SSE request failed: " + detail)

        trace_id = resp.headers["x-trace-id"]
        return self._iter_text(resp), trace_id

    @staticmethod
    def _iter_text(resp) -> Iterator[str]:
        """Yield the streamed response body as text, decoding UTF-8 incrementally.

        A multi-byte character split across chunks is held back until it is
        complete. Invalid bytes become U+FFFD instead of stalling the stream.
        The response is closed when iteration ends or the generator is closed.
        """
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        try:
            for raw in resp.iter_content():
                text = decoder.decode(raw)
                if text:
                    yield text
            tail = decoder.decode(b"", final=True)
            if tail:
                yield tail
        finally:
            resp.close()

    def get_trace_by_message_id(self, message_id: str):
        """Return the decoded JSON trace recorded for a message id."""
        resp = self._request("GET", "/api/v1/access/messages/{}/trace".format(message_id))
        return resp.json()

    def get_trace_by_trace_id(self, trace_id: str):
        """Return the decoded JSON trace for a trace id (e.g. from ``x-trace-id``)."""
        resp = self._request("GET", "/api/v1/access/traces/{}".format(trace_id))
        return resp.json()

    def post_new_conversation_id(self):
        """Create a new conversation for this agent and return its ``_id``."""
        resp = self._request("POST", "/api/v1/access/agents/{}/conversations/new".format(self.agent_id))
        return resp.json()["_id"]

    def delete_conversation(self, conversation_id: str):
        """Delete a conversation and return the decoded JSON response."""
        resp = self._request(
            "DELETE",
            "/api/v1/access/conversations/{}".format(conversation_id),
            json_content=True,
        )
        return resp.json()

    def update_context(self, messages: List[dict], conversation_id: str):
        """Replace a conversation's context with ``messages``; return the decoded JSON response."""
        resp = self._request(
            "PUT",
            "/api/v1/access/conversations/{}/update-context".format(conversation_id),
            json_content=True,
            payload={"messages": messages},
        )
        return resp.json()

    @staticmethod
    def _extract_reference(trace):
        """Convert a trace response into a list of ``(title, source)`` tuples.

        The sentinel string ``"No trace"`` is passed through unchanged.
        """
        if trace == "No trace":
            return "No trace"
        return [(t["title"], t["source"]) for t in trace]

    def question_answer(self,
                        question: str,
                        conversation_id: Optional[str] = None,
                        context: Optional[List[dict]] = None,
                        streaming: bool = False):
        """Ask ``question`` and collect the answer plus its references.

        :param conversation_id: conversation to use; a new one is created
            when ``None``.
        :param context: optional message list applied with ``update_context``
            before asking.
        :param streaming: use the SSE endpoint and echo chunks to stdout as
            they arrive.
        :returns: ``{"answer": str, "reference": [(title, source), ...] or "No trace"}``.

        NOTE: the conversation is always deleted afterwards, even when the
        caller supplied ``conversation_id``.
        """
        if conversation_id is None:
            conversation_id = self.post_new_conversation_id()
        context_flag = 0 if context is None else 1

        result = {
            "answer": None,
            "reference": None
        }
        try:
            # Seed the context inside the try so the conversation is still
            # deleted if this call fails.
            if context is not None:
                self.update_context(context, conversation_id)

            if streaming:
                resp_stream, trace_id = self.post_message_by_sse(
                    question, conversation_id, context=context_flag)
                pieces = []
                for chunk in resp_stream:
                    print(chunk, end="")
                    pieces.append(chunk)
                print("\n")
                result["answer"] = "".join(pieces)
                trace = self.get_trace_by_trace_id(trace_id)
            else:
                js = self.post_message(question, conversation_id, context=context_flag)
                result["answer"] = js["text"]
                trace = self.get_trace_by_message_id(js["_id"])

            result["reference"] = self._extract_reference(trace)
        finally:
            self.delete_conversation(conversation_id)
        return result


def main():
    """Demo: ask the "Yutong Bus" agent one streamed question with a seeded context."""
    args = get_args()

    agent = AgentX(
        api_key=args.api_key,
        agent_name="Yutong Bus",
    )
    print(agent)

    context = [
        {
            "user": "你好"
        },
        {
            "assistant": "你好,我们是宇通客车公司,有什么可以帮到您的吗?"
        },
        {
            "user": "需要一辆55座客车。"
        },
        {
            "assistant": "Which country will the bus be used in?"
        },
        {
            "user": "你可以说中文吗。"
        },
        {
            "assistant": "可以的,请问您需要在哪个国家使用客车?"
        },
    ]
    question = "你好"

    time_begin = time.time()
    response = agent.question_answer(question, context=context, streaming=True)
    time_cost = time.time() - time_begin
    print(response)
    print("time cost: {}".format(time_cost))
    return


if __name__ == '__main__':
    main()