import os
import requests
import gradio as gr
from conversation import Conversation


class BaseModel:
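    """Client wrapper for a single hosted text-generation endpoint.

    Builds a plain-text prompt from a Conversation and posts it to a remote
    inference API whose URL template is read from the API_BASE_PATH
    environment variable.
    """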
    name: str
    endpoint: str
    namespace: str
    generation_params: dict

    def __init__(self, name, endpoint, namespace, generation_params):
        self.name = name
        self.endpoint = endpoint
        self.namespace = namespace
        self.generation_params = generation_params

    def generate_response(self, conversation, custom_generation_params=None):
        prompt = self._get_prompt(conversation)
        response = self._get_response(prompt, custom_generation_params)
        return response

    def _get_prompt(self, conversation: Conversation):
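        # Flatten the conversation into one plain-text prompt: memory and
        # base prompt first, then each turn as "<speaker>: <text>", ending
        # with the bot label so the model completes the bot's next turn.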
        prompt = "\n".join(
            [conversation.memory, conversation.prompt]
        ).strip()
        for message in conversation.messages:
            prompt += f"\n{message['from'].strip()}: {message['value'].strip()}"
        prompt += f"\n{conversation.bot_label}:"
        return prompt

    def _get_response(self, text, custom_generation_params):
        # API_BASE_PATH holds a URL template with escaped braces (\{\});
        # restore real braces so str.format can fill in endpoint and namespace.
        api_template = os.environ.get("API_BASE_PATH")
        if api_template is None:
            raise gr.Error("API_BASE_PATH environment variable is not set.")
        api = api_template.replace(r"\{\}", "{}").format(self.endpoint, self.namespace)
        # Copy the defaults so per-request overrides do not mutate the
        # shared generation_params dict on the model instance.
        parameters = dict(self.generation_params)
        if custom_generation_params is not None:
            parameters.update(custom_generation_params)
        payload = {"instances": [text], "parameters": parameters}
        resp = requests.post(api, json=payload, timeout=600)
        if resp.status_code != 200:
            raise gr.Error(f"Endpoint returned code {resp.status_code}. "
                           f"Possible solutions: "
                           f"1. Scale-to-Zero is enabled, so wait a few minutes and try again. "
                           f"2. The generated response may be too long; try lowering max_new_tokens. "
                           f"3. If nothing helps, report the problem.")
        return resp.json()["predictions"][0].strip()
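

if __name__ == "__main__":
    # Minimal smoke test of prompt assembly (an illustrative sketch, not part
    # of the app). SimpleNamespace stands in for Conversation because its
    # constructor lives in conversation.py and is not shown here; every field
    # value below is a hypothetical example.
    from types import SimpleNamespace

    demo_conversation = SimpleNamespace(
        memory="The assistant is concise and friendly.",
        prompt="Continue the chat below.",
        messages=[
            {"from": "User", "value": "Hello!"},
            {"from": "Bot", "value": "Hi! How can I help?"},
        ],
        bot_label="Bot",
    )
    model = BaseModel(
        name="demo-model",
        endpoint="demo-endpoint",
        namespace="demo-namespace",
        generation_params={"max_new_tokens": 64},
    )
    print(model._get_prompt(demo_conversation))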