from typing import Union

from ctranslate2 import Translator
from fastapi import FastAPI
from pydantic import BaseModel
from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool token embeddings into one vector, ignoring padded positions."""
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


# Embedding model: local replacement for OpenAI's text-embedding-ada-002
embeddingTokenizer = AutoTokenizer.from_pretrained(
    './multilingual-e5-base')
embeddingModel = AutoModel.from_pretrained('./multilingual-e5-base')

# Generation model: local replacement for the ChatGPT API, int8-quantized via CTranslate2
inferenceTokenizer = AutoTokenizer.from_pretrained(
    "./ct2fast-flan-alpaca-xl")
inferenceTranslator = Translator(
    "./ct2fast-flan-alpaca-xl", compute_type="int8", device="cpu")


class EmbeddingRequest(BaseModel):
    input: Union[str, None] = None


class TokensCountRequest(BaseModel):
    input: Union[str, None] = None


class InferenceRequest(BaseModel):
    input: Union[str, None] = None
    max_length: Union[int, None] = None


app = FastAPI()


@app.get("/")
async def root():
    return {"message": "Hello World"}


@app.post("/text-embedding")
async def text_embedding(request: EmbeddingRequest):
    input_text = request.input

    # Tokenize the input (truncated to 512 tokens), run the model, and mean-pool
    batch_dict = embeddingTokenizer([input_text], max_length=512,
                                    padding=True, truncation=True, return_tensors='pt')
    outputs = embeddingModel(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state,
                              batch_dict['attention_mask'])

    # create response
    return {
        'embedding': embeddings[0].tolist()
    }


@app.post('/inference')
async def inference(request: InferenceRequest):
    input_text = request.input
    # Clamp the requested decoding length to at most 1024 tokens, falling back
    # to 256 when the value is missing or not a positive integer
    max_length = 256
    try:
        requested = int(request.max_length)
        if requested > 0:
            max_length = min(1024, requested)
    except (TypeError, ValueError):
        pass

    # Convert the prompt into the token strings expected by CTranslate2
    input_tokens = inferenceTokenizer.convert_ids_to_tokens(
        inferenceTokenizer.encode(input_text))

    results = inferenceTranslator.translate_batch(
        [input_tokens],
        beam_size=1,
        max_input_length=0,
        max_decoding_length=max_length,
        num_hypotheses=1,
        repetition_penalty=1.3,
        sampling_topk=40,
        sampling_temperature=0.7,
        use_vmap=False)

    output_tokens = results[0].hypotheses[0]
    output_text = inferenceTokenizer.decode(
        inferenceTokenizer.convert_tokens_to_ids(output_tokens))

    # create response
    return {
        'generated_text': output_text
    }


@app.post('/tokens-count')
async def tokens_count(request: TokensCountRequest):
    input_text = request.input

    tokens = inferenceTokenizer.convert_ids_to_tokens(
        inferenceTokenizer.encode(input_text))

    # create response
    return {
        'tokens': tokens,
        'total': len(tokens)
    }
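

# ---------------------------------------------------------------------------
# Usage sketch (not executed by this module): assumes the service is started
# with `uvicorn main:app --port 8000` (the module name `main` and the port are
# assumptions) and that the `requests` package is installed on the client.
# Endpoint paths and payload fields match the routes defined above.
#
#   import requests
#
#   base_url = "http://localhost:8000"
#
#   # Embedding (E5 models are typically prompted with "query: "/"passage: " prefixes)
#   emb = requests.post(f"{base_url}/text-embedding",
#                       json={"input": "query: how do embeddings work?"}).json()
#   print(len(emb["embedding"]))
#
#   # Text generation, capped at 128 decoded tokens
#   gen = requests.post(f"{base_url}/inference",
#                       json={"input": "Write a haiku about servers.",
#                             "max_length": 128}).json()
#   print(gen["generated_text"])
#
#   # Token counting
#   cnt = requests.post(f"{base_url}/tokens-count",
#                       json={"input": "How many tokens is this sentence?"}).json()
#   print(cnt["total"])
# ---------------------------------------------------------------------------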