from typing import Union

import torch
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from ctranslate2 import Translator
from fastapi import FastAPI
from pydantic import BaseModel


def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
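    """Mean-pool the last hidden states over the sequence dimension,
    zeroing masked (padding) positions so they do not affect the average.

    Shapes: (batch, seq, hidden) + (batch, seq) -> (batch, hidden).
    """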
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


# Embedding model: a local stand-in for OpenAI's text-embedding-ada-002
embeddingTokenizer = AutoTokenizer.from_pretrained(
    './multilingual-e5-base')
embeddingModel = AutoModel.from_pretrained('./multilingual-e5-base')
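
# Note: the multilingual-e5 model card recommends prefixing inputs with
# "query: " or "passage: "; callers are expected to add the prefix themselves.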

# Generation model: a local ChatGPT stand-in (fastchat-t5-3b converted to CTranslate2)
inferenceTokenizer = AutoTokenizer.from_pretrained(
    "./fastchat-t5-3b-ct2")
inferenceTranslator = Translator(
    "./fastchat-t5-3b-ct2", compute_type="int8", device="cpu")


class EmbeddingRequest(BaseModel):
    input: Union[str, None] = None


class TokensCountRequest(BaseModel):
    input: Union[str, None] = None


class InferenceRequest(BaseModel):
    input: Union[str, None] = None
    max_length: Union[int, None] = None  # falls back to 256 in the endpoint


app = FastAPI()


@app.get("/")
async def root():
    return {"message": "Hello World"}


@app.post("/text-embedding")
async def text_embedding(request: EmbeddingRequest):
    input = request.input

    # Process the input data
    batch_dict = embeddingTokenizer([input], max_length=512,
                                    padding=True, truncation=True, return_tensors='pt')
    outputs = embeddingModel(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state,
                              batch_dict['attention_mask'])

    # create response
    return {
        'embedding': embeddings[0].tolist()
    }


@app.post('/inference')
async def inference(request: InferenceRequest):
    input_text = request.input
    max_length = 256
    try:
        # Clamp the requested decoding length to a sane range.
        max_length = min(1024, max(1, int(request.max_length)))
    except (TypeError, ValueError):
        pass  # missing or malformed value: keep the 256-token default

    # process request
    input_tokens = inferenceTokenizer.convert_ids_to_tokens(
        inferenceTokenizer.encode(input_text))

    results = inferenceTranslator.translate_batch(
        [input_tokens],
        beam_size=1,
        max_input_length=0,  # 0 disables input truncation
        max_decoding_length=max_length,
        num_hypotheses=1,
        repetition_penalty=1.3,
        sampling_topk=40,
        sampling_temperature=0.7,
        use_vmap=False)

    output_tokens = results[0].hypotheses[0]
    output_text = inferenceTokenizer.decode(
        inferenceTokenizer.convert_tokens_to_ids(output_tokens))

    # create response
    return {
        'generated_text': output_text
    }


@app.post('/tokens-count')
async def tokens_count(request: TokensCountRequest):
    input_text = request.input

    tokens = inferenceTokenizer.convert_ids_to_tokens(
        inferenceTokenizer.encode(input_text))

    # create response
    return {
        'tokens': tokens,
        'total': len(tokens)
    }
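

# --- Usage sketch (assumes this file is saved as main.py and uvicorn is
# installed; adjust the module name and port to your setup) ---
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/text-embedding \
#        -H 'Content-Type: application/json' \
#        -d '{"input": "query: hello world"}'
#
#   curl -X POST http://localhost:8000/inference \
#        -H 'Content-Type: application/json' \
#        -d '{"input": "What is FastAPI?", "max_length": 128}'
#
# Caveat: the endpoints are async but the model calls are blocking, so
# concurrent requests are effectively serialized; plain `def` endpoints
# would let FastAPI run them in its threadpool instead.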