import os
import requests
import gradio as gr
from conversation import Conversation


class BaseModel:
    """Wrapper around a single hosted text-generation endpoint."""

    name: str
    endpoint: str
    namespace: str
    generation_params: dict

    def __init__(self, name, endpoint, namespace, generation_params):
        self.name = name
        self.endpoint = endpoint
        self.namespace = namespace
        self.generation_params = generation_params
    def generate_response(self, conversation, custom_generation_params=None):
        """Build a prompt from the conversation and query the endpoint."""
        prompt = self._get_prompt(conversation)
        return self._get_response(prompt, custom_generation_params)
    def _get_prompt(self, conversation: Conversation):
        """Flatten a Conversation into a plain-text prompt."""
        # Memory and persona prompt come first, then the message history.
        prompt = "\n".join(
            [conversation.memory, conversation.prompt]
        ).strip()
        for message in conversation.messages:
            prompt += f"\n{message['from'].strip()}: {message['value'].strip()}"
        # A trailing "<bot_label>:" cues the model to answer as the bot.
        prompt += f"\n{conversation.bot_label}:"
        return prompt
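
    # Illustrative shape of the prompt built above; the speaker labels and
    # message contents are assumptions, not taken from conversation.py:
    #
    #   <memory text>
    #   <prompt text>
    #   User: Hi there!
    #   Bot: Hello! How can I help?
    #   User: Tell me a joke.
    #   Bot: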
    def _get_response(self, text, custom_generation_params):
        """POST the prompt to the serving endpoint and return the generation."""
        # API_BASE_PATH is a URL template with escaped braces (e.g. two "\{\}"
        # slots); unescape them so str.format can fill in the endpoint and
        # namespace. The raw string keeps the backslashes literal.
        api = os.environ["API_BASE_PATH"].replace(r"\{\}", "{}")
        api = api.format(self.endpoint, self.namespace)
        # Copy the defaults so per-request overrides don't mutate shared state.
        parameters = dict(self.generation_params)
        if custom_generation_params is not None:
            parameters.update(custom_generation_params)
        payload = {"instances": [text], "parameters": parameters}
        resp = requests.post(api, json=payload, timeout=600)
        if resp.status_code != 200:
            raise gr.Error(
                f"Endpoint returned code {resp.status_code}. "
                f"Possible fixes: "
                f"1. Scale-to-Zero is enabled, so please wait a few minutes and try again. "
                f"2. The generated response may be too big; try lowering max_new_tokens. "
                f"3. If nothing helps, report the problem."
            )
        # The serving API responds with {"predictions": [<generated text>, ...]}.
        return resp.json()["predictions"][0].strip()
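

# Illustrative usage sketch. The names and parameter values below are
# assumptions, and constructing a Conversation depends on conversation.py,
# which is not shown here:
#
#   model = BaseModel(
#       name="example-model",
#       endpoint="example-endpoint",
#       namespace="example-namespace",
#       generation_params={"max_new_tokens": 64, "temperature": 0.7},
#   )
#   reply = model.generate_response(
#       conversation,  # a Conversation with memory, prompt, messages, bot_label
#       custom_generation_params={"top_p": 0.9},
#   )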