# Monster-LLMs / MonsterAPIClient.py
# Author: VikasQblocks
# Commit ecc07c2: "Fix Client and serving models"
#MonsterAPIClient.py
"""
Monster API Python client to connect to LLM models on monsterapi
Base URL: https://api.monsterapi.ai/v1/generate/{model}
Available models:
-----------------
1. falcon-7b-instruct
2. falcon-40b-instruct
3. mpt-30B-instruct
4. mpt-7b-instruct
5. openllama-13b-base
6. llama2-7b-chat
"""
import os
import time
import logging
import requests
from requests_toolbelt.multipart.encoder import MultipartEncoder
from typing import Optional, Literal, Union, List, Dict
from pydantic import BaseModel, Field
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Configure the root logger at INFO level when the module is imported
# (per the logging docs, basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level="INFO")
class InputModel1(BaseModel):
    """
    Request payload schema for falcon-7b-instruct, falcon-40b-instruct,
    openllama-13b-base and llama2-7b-chat.

    The client sends every field to the API, so the defaults below are the
    values the API receives when the caller omits a field.

    Fields:
        prompt (str, required): Textual instruction for the model to produce an output.
        top_k (int, default 40): Top-k sampling; removes the low-probability tail,
            making the output less likely to go off topic.
        top_p (float, default 0.9, must be within 0-1): Top-p sampling; higher values
            consider a broader range of tokens, giving more diverse text.
        temp (float, default 0.98, must be within 0-1): Temperature; controls the
            randomness of next-token predictions.
        max_length (int, default 256): Maximum length of the generated text.
        repetition_penalty (float, default 1.2): Penalty discouraging repeated tokens.
        beam_size (int, default 1): Beam width for beam search; larger values may improve
            quality at the cost of slower generation.
    """
    prompt: str
    top_k: int = 40
    # Bounds are enforced by pydantic at construction time.
    top_p: float = Field(0.9, ge=0.0, le=1.0)
    temp: float = Field(0.98, ge=0.0, le=1.0)
    max_length: int = 256
    repetition_penalty: float = 1.2
    beam_size: int = 1
class InputModel2(BaseModel):
    """
    Request payload schema for mpt-30B-instruct and mpt-7b-instruct.

    The client sends every field to the API, so the defaults below are the
    values the API receives when the caller omits a field.

    Fields:
        prompt (str, required): Textual instruction for the model to produce an output.
        top_k (int, default 40): Top-k sampling; removes the low-probability tail,
            making the output less likely to go off topic.
        top_p (float, default 0.9, must be within 0-1): Top-p sampling; higher values
            consider a broader range of tokens, giving more diverse text.
        temp (float, default 0.98, must be within 0-1): Temperature; the higher it is,
            the more random the output.
        max_length (int, default 256): Maximum length of the generated output.
    """
    prompt: str
    top_k: int = 40
    # Bounds are enforced by pydantic at construction time.
    top_p: float = Field(0.9, ge=0.0, le=1.0)
    temp: float = Field(0.98, ge=0.0, le=1.0)
    max_length: int = 256
# Maps each supported model name to the pydantic schema that validates its
# request payload. The name is also used verbatim as the URL path segment
# (note the upper-case "B" in "mpt-30B-instruct").
MODELS_TO_DATAMODEL = dict([
    ('falcon-7b-instruct', InputModel1),
    ('falcon-40b-instruct', InputModel1),
    ('mpt-30B-instruct', InputModel2),
    ('mpt-7b-instruct', InputModel2),
    ('openllama-13b-base', InputModel1),
    ('llama2-7b-chat', InputModel1),
])
class MClient:
    """
    Client for the Monster API generate/status endpoints.

    Typical use: ``get_response`` submits a request and returns the API's JSON
    (the demo below reads ``process_id`` from it). ``wait_and_get_result``
    then polls ``get_status`` until the process completes, fails, or times out.

    Environment variables:
        MONSTER_API_KEY: bearer token, used when ``api_key`` is not passed.
        MOCK_Runner: when "true" (case-insensitive), ``wait_and_get_result``
            returns a canned string instead of waiting on an unfinished process.
    """

    def __init__(self, api_key: Optional[str] = None, request_timeout: Optional[float] = 60.0):
        """
        Args:
            api_key: Monster API key. Falls back to MONSTER_API_KEY when None.
            request_timeout: Timeout in seconds for each HTTP request, passed to
                ``requests``. None disables the timeout.

        Raises:
            ValueError: No API key is available from either source.
        """
        self.auth_token = api_key if api_key is not None else os.environ.get('MONSTER_API_KEY')
        if not self.auth_token:
            # Fail fast with a clear message rather than a TypeError from 'Bearer ' + None.
            raise ValueError(
                "Monster API key missing: pass api_key or set the MONSTER_API_KEY environment variable."
            )
        # Fixed multipart boundary; the same value is handed to MultipartEncoder
        # in get_response, so the content-type header matches the body.
        self.boundary = '---011000010111000001101001'
        self.headers = {
            "accept": "application/json",
            "content-type": f"multipart/form-data; boundary={self.boundary}",
            'Authorization': 'Bearer ' + self.auth_token}
        self.base_url = 'https://api.monsterapi.ai/v1'
        self.models_to_data_model = MODELS_TO_DATAMODEL
        self.mock = os.environ.get('MOCK_Runner', "False").lower() == "true"
        self.request_timeout = request_timeout

    def get_response(self, model: Literal['falcon-7b-instruct', 'falcon-40b-instruct', 'mpt-30B-instruct', 'mpt-7b-instruct', 'openllama-13b-base', 'llama2-7b-chat'],
                     data: dict):
        """
        Validate ``data`` against the model's schema and POST it to /generate/{model}.

        Args:
            model: Name of a supported model (a key of ``self.models_to_data_model``).
            data: Request fields; omitted optional fields take the schema defaults.

        Returns:
            The decoded JSON body of the API response.

        Raises:
            ValueError: ``model`` is not supported.
            requests.HTTPError: The API answered with an error status.
            (The pydantic schema also raises its validation error for bad ``data``.)
        """
        if model not in self.models_to_data_model:
            raise ValueError(f"Invalid model: {model}!")
        data_model = self.models_to_data_model[model](**data)
        url = f"{self.base_url}/generate/{model}"
        # pydantic v2 renamed .dict() to .model_dump(); prefer the new name when present.
        dump = getattr(data_model, "model_dump", None) or data_model.dict
        payload = dump()
        logger.info("Calling Monster API with url: %s, with payload: %s", url, payload)
        # Send every value as a string in the multipart form.
        fields = {key: str(value) for key, value in payload.items()}
        multipart_data = MultipartEncoder(fields=fields, boundary=self.boundary)
        response = requests.post(url, headers=self.headers, data=multipart_data,
                                 timeout=self.request_timeout)
        response.raise_for_status()
        return response.json()

    def get_status(self, process_id):
        """
        Return the decoded JSON from GET /v1/status/{process_id}.

        Raises:
            requests.HTTPError: The API answered with an error status.
        """
        url = f"{self.base_url}/status/{process_id}"
        response = requests.get(url, headers=self.headers, timeout=self.request_timeout)
        response.raise_for_status()
        return response.json()

    def wait_and_get_result(self, process_id, timeout=100, poll_interval: float = 1.0):
        """
        Poll ``get_status`` until the process finishes.

        Args:
            process_id: Id of a submitted process.
            timeout: Overall time budget in seconds.
            poll_interval: Seconds to sleep between polls. The sleep is never
                longer than the time left in the budget.

        Returns:
            The status response's ``result`` once its ``status`` is "completed"
            (case-insensitive). In mock mode, returns a canned string when the
            status is neither completed nor failed.

        Raises:
            TimeoutError: The budget ran out before the process finished.
            RuntimeError: The status is "failed" (case-insensitive).
        """
        # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
        deadline = time.monotonic() + timeout
        while True:
            if time.monotonic() >= deadline:
                raise TimeoutError(f"Process {process_id} timed out after {timeout} seconds.")
            status = self.get_status(process_id)
            state = status['status'].lower()
            if state == 'completed':
                return status['result']
            if state == 'failed':
                raise RuntimeError(f"Process {process_id} failed!")
            if self.mock:
                return 100 * "Mock Output!"
            logger.info("Process %s is still running, status is %s. Waiting ...",
                        process_id, status['status'])
            # Avoid hammering the API, and never sleep past the deadline.
            time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))
if __name__ == '__main__':
    # Demo: submit one prompt to falcon-7b-instruct, wait for it, print the output.
    demo_client = MClient()
    submission = demo_client.get_response('falcon-7b-instruct', {"prompt": 'How to make a sandwich?'})
    generated = demo_client.wait_and_get_result(submission['process_id'])
    print(generated)