File size: 2,436 Bytes
9f341cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import json
import re
import requests
from messagers.message_outputer import OpenaiStreamOutputer
from utils.logger import logger
from utils.enver import enver


class MessageStreamer:
    """Stream text-generation output from the HuggingFace Inference API.

    Each generated token is logged and, optionally, re-emitted through
    ``OpenaiStreamOutputer`` (presumably OpenAI-compatible stream chunks —
    see that class for the exact format).
    """

    # Short model aliases accepted by __init__ -> HuggingFace model repo ids.
    MODEL_MAP = {
        "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    }

    # Matches the "data:" field prefix only at the very start of a line, so a
    # "data:" that appears inside generated token text is left untouched.
    DATA_PREFIX_PATTERN = re.compile(r"^data:\s*")

    def __init__(self, model: str):
        """Select the backend model.

        Args:
            model: A key of ``MODEL_MAP``, e.g. ``"mixtral-8x7b"``.

        Raises:
            KeyError: If ``model`` is not a key of ``MODEL_MAP``.
        """
        self.model = model
        try:
            self.model_fullname = self.MODEL_MAP[model]
        except KeyError:
            raise KeyError(
                f"Unknown model {model!r}; expected one of {sorted(self.MODEL_MAP)}"
            ) from None

    def parse_line(self, line):
        """Extract the token text from one streamed response line.

        Args:
            line: A raw ``bytes`` line as yielded by ``Response.iter_lines()``,
                e.g. ``b'data:{"token": {"text": "Hi"}}'``. The leading
                ``data:`` prefix is optional.

        Returns:
            The ``token.text`` field of the JSON payload.

        Raises:
            ValueError: If the payload is not valid JSON
                (``json.JSONDecodeError`` is a ``ValueError``), or if the
                payload carries an ``error`` field instead of a token.
            KeyError: If the payload has neither ``error`` nor ``token.text``.
        """
        line = line.decode("utf-8")
        line = self.DATA_PREFIX_PATTERN.sub("", line, count=1)
        data = json.loads(line)
        # Assumes server-side failures arrive as {"error": ...} objects
        # (TGI convention — confirm); report the message instead of a KeyError.
        if isinstance(data, dict) and "error" in data:
            raise ValueError(f"HuggingFace stream error: {data['error']}")
        return data["token"]["text"]

    def chat(
        self,
        prompt: str = None,
        temperature: float = 0.01,
        max_new_tokens: int = 32000,
        stream: bool = True,
        yield_output: bool = False,
    ):
        """Send ``prompt`` to the model and process the generated tokens.

        This is a generator: no request is sent until it is iterated, even
        when ``yield_output`` is False. Every token is logged with
        ``logger.mesg`` as it arrives.

        Args:
            prompt: Text sent as the ``inputs`` field of the request body.
            temperature: Sampling temperature forwarded to the API.
            max_new_tokens: Maximum number of tokens to generate.
            stream: Sent as the body's ``stream`` flag and passed to
                ``requests.post(stream=...)``. NOTE(review): parsing assumes
                the streamed ``data:`` line format; a non-streamed response is
                presumably shaped differently — confirm before using False.
            yield_output: If True, yield each token wrapped by
                ``OpenaiStreamOutputer.output``; otherwise yield nothing.

        Yields:
            The result of ``OpenaiStreamOutputer.output`` once per token, with
            ``content_type`` ``"Finished"`` for the ``</s>`` token and
            ``"Completions"`` otherwise.

        Raises:
            requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
            ValueError: If a streamed line is malformed or reports an error.
        """
        # https://huggingface.co/docs/text-generation-inference/conceptual/streaming#streaming-with-curl
        self.request_url = (
            f"https://api-inference.huggingface.co/models/{self.model_fullname}"
        )
        self.message_outputer = OpenaiStreamOutputer()
        self.request_headers = {
            "Content-Type": "application/json",
        }
        # huggingface_hub/inference/_client.py: class InferenceClient > def text_generation()
        self.request_body = {
            "inputs": prompt,
            "parameters": {
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
                "return_full_text": False,
            },
            "stream": stream,
        }
        print(self.request_url)
        enver.set_envs(proxies=True)
        # The context manager releases the connection even if the consumer
        # stops iterating early (closing the generator exits the `with`).
        with requests.post(
            self.request_url,
            headers=self.request_headers,
            json=self.request_body,
            proxies=enver.requests_proxies,
            stream=stream,
        ) as response:
            print(response.status_code)
            # An error body is not made of token lines; surface the server's
            # message instead of failing later inside parse_line.
            if not response.ok:
                raise requests.HTTPError(
                    f"{response.status_code} from {self.request_url}: {response.text}",
                    response=response,
                )
            for line in response.iter_lines():
                if not line:
                    continue  # skip empty lines (e.g. event separators)

                content = self.parse_line(line)

                if content.strip() == "</s>":
                    content_type = "Finished"
                    logger.mesg("\n[Finished]")
                else:
                    content_type = "Completions"
                    logger.mesg(content, end="")

                if yield_output:
                    yield self.message_outputer.output(
                        content=content, content_type=content_type
                    )