|
import json |
|
import os |
|
import subprocess |
|
import sys |
|
import time |
|
|
|
import aiohttp |
|
import requests |
|
|
|
from lagent.schema import AgentMessage |
|
|
|
|
|
class HTTPAgentClient:
    """Synchronous client for an agent served over HTTP.

    The client talks to three endpoints under ``http://{host}:{port}``:
    ``/health_check``, ``/chat_completion`` and ``/memory/{session_id}``.

    Args:
        host: Host name or IP address of the server.
        port: TCP port of the server.
        timeout: Value forwarded as the ``timeout`` argument of every
            ``requests`` call made by this client.
    """

    def __init__(self, host='127.0.0.1', port=8090, timeout=None):
        self.host = host
        self.port = port
        self.timeout = timeout

    def _url(self, path):
        """Return the absolute URL of ``path`` on the configured server."""
        return f'http://{self.host}:{self.port}/{path}'

    @property
    def is_alive(self):
        """Whether ``GET /health_check`` answers with status code 200.

        Request-level failures (connection refused, timeout, ...) are
        reported as ``False`` instead of being raised.
        """
        try:
            resp = requests.get(
                self._url('health_check'), timeout=self.timeout)
        except requests.RequestException:
            # Only transport/HTTP failures mean "not alive". A bare
            # ``except`` would also hide KeyboardInterrupt, SystemExit and
            # genuine programming errors.
            return False
        return resp.status_code == 200

    def __call__(self, *message, session_id: int = 0, **kwargs):
        """Send messages to ``/chat_completion`` and return the reply.

        Args:
            *message: Messages to send. Strings are sent as-is; any other
                object is serialized with its ``model_dump()`` method.
            session_id: Session identifier included in the request body.
            **kwargs: Extra fields merged into the JSON request body. They
                are merged last, so they override ``message`` and
                ``session_id`` on a key clash.

        Returns:
            An ``AgentMessage`` validated from the response body when the
            status code is 200; otherwise the decoded JSON body unchanged.
        """
        payload = {
            'message': [
                m if isinstance(m, str) else m.model_dump() for m in message
            ],
            'session_id': session_id,
            **kwargs,
        }
        response = requests.post(
            self._url('chat_completion'),
            json=payload,
            headers={'Content-Type': 'application/json'},
            timeout=self.timeout)
        resp = response.json()
        if response.status_code != 200:
            # Hand the server's error payload back to the caller untouched.
            return resp
        return AgentMessage.model_validate(resp)

    def state_dict(self, session_id: int = 0):
        """Return the decoded JSON body of ``GET /memory/{session_id}``."""
        resp = requests.get(
            self._url(f'memory/{session_id}'), timeout=self.timeout)
        return resp.json()
|
|
|
|
|
class HTTPAgentServer(HTTPAgentClient):
    """Client that also launches the HTTP agent server it talks to.

    Constructing an instance runs ``lagent/distributed/http_serve/app.py``
    (a path relative to the current working directory) in a child process
    and blocks until the child prints ``'Uvicorn running on'`` or closes its
    output.

    Args:
        gpu_id: Device selection exported to the child process as
            ``CUDA_VISIBLE_DEVICES``. May be a string, an int, or a
            list/tuple of either (joined with commas). ``None`` leaves the
            inherited environment untouched.
        config: JSON-serializable configuration, passed to the server
            script as ``--config <json>``.
        host: Host passed to the server script and used by the client.
        port: Port passed to the server script and used by the client.
    """

    def __init__(self, gpu_id, config, host='127.0.0.1', port=8090):
        super().__init__(host, port)
        self.gpu_id = gpu_id
        self.config = config
        self.start_server()

    @staticmethod
    def _format_gpu_ids(gpu_id):
        """Render ``gpu_id`` as a ``CUDA_VISIBLE_DEVICES`` string."""
        if isinstance(gpu_id, (list, tuple)):
            return ','.join(str(device) for device in gpu_id)
        return str(gpu_id)

    @staticmethod
    def _drain(stream):
        """Read and discard ``stream`` until EOF, then close it."""
        with stream:
            for _ in stream:
                pass

    def start_server(self):
        """Spawn the server process and wait until it reports readiness.

        The child's combined stdout/stderr is echoed to ``sys.stdout`` up
        to and including the line containing ``'Uvicorn running on'``.
        After that the remaining output is consumed and discarded by a
        daemon thread.
        """
        # Imported here so this class does not depend on the module's
        # import block.
        import threading

        env = os.environ.copy()
        if self.gpu_id is not None:
            # Environment values must be strings; an int or list passed
            # straight through would make ``Popen`` fail.
            env['CUDA_VISIBLE_DEVICES'] = self._format_gpu_ids(self.gpu_id)
        cmds = [
            sys.executable, 'lagent/distributed/http_serve/app.py',
            '--host', self.host,
            '--port', str(self.port),
            '--config', json.dumps(self.config),
        ]
        self.process = subprocess.Popen(
            cmds,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True)

        # ``readline`` blocks until a full line is available and returns ''
        # at EOF, so no sleep is needed between iterations.
        for output in iter(self.process.stdout.readline, ''):
            sys.stdout.write(output)
            sys.stdout.flush()
            if 'Uvicorn running on' in output:
                break

        # Keep consuming the pipe after startup. An unread pipe eventually
        # fills up, which blocks the server on its next write and can make
        # ``shutdown`` hang in ``wait()``.
        self._drain_thread = threading.Thread(
            target=self._drain, args=(self.process.stdout,), daemon=True)
        self._drain_thread.start()

    def shutdown(self):
        """Terminate the server process and wait for it to exit."""
        self.process.terminate()
        self.process.wait()
|
|
|
|
|
class AsyncHTTPAgentMixin:
    """Mixin providing an ``aiohttp``-based coroutine ``__call__``.

    It reads ``self.host``, ``self.port`` and ``self.timeout``, which the
    class it is mixed into must provide.
    """

    async def __call__(self, *message, session_id: int = 0, **kwargs):
        """POST the messages to ``/chat_completion`` and return the reply.

        Args:
            *message: Messages to send. Strings are sent as-is; any other
                object is serialized with its ``model_dump()`` method.
            session_id: Session identifier included in the request body.
            **kwargs: Extra fields merged last into the JSON request body.

        Returns:
            An ``AgentMessage`` validated from the response body when the
            status is 200; otherwise the decoded JSON body unchanged.
        """
        client_timeout = aiohttp.ClientTimeout(self.timeout)
        # A fresh session is opened for every call.
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            endpoint = f'http://{self.host}:{self.port}/chat_completion'
            serialized = []
            for item in message:
                if isinstance(item, str):
                    serialized.append(item)
                else:
                    serialized.append(item.model_dump())
            payload = {
                'message': serialized,
                'session_id': session_id,
                **kwargs,
            }
            request = session.post(
                endpoint,
                json=payload,
                headers={'Content-Type': 'application/json'},
            )
            async with request as response:
                body = await response.json()
                if response.status == 200:
                    return AgentMessage.model_validate(body)
                # Non-200: return the server's payload as received.
                return body
|
|
|
|
|
class AsyncHTTPAgentClient(AsyncHTTPAgentMixin, HTTPAgentClient):
    """``HTTPAgentClient`` with the coroutine ``__call__`` of the mixin.

    ``AsyncHTTPAgentMixin`` is listed first so that its ``__call__`` takes
    precedence over the one defined on ``HTTPAgentClient``.
    """
|
|
|
|
|
class AsyncHTTPAgentServer(AsyncHTTPAgentMixin, HTTPAgentServer):
    """``HTTPAgentServer`` with the coroutine ``__call__`` of the mixin.

    ``AsyncHTTPAgentMixin`` is listed first so that its ``__call__`` takes
    precedence over the one inherited through ``HTTPAgentServer``.
    """
|
|