from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from ctranslate2 import Translator
from typing import Union
from fastapi import FastAPI
from pydantic import BaseModel
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    # Zero out padding positions, then mean-pool over the sequence dimension.
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

# text-ada replacement
embeddingTokenizer = AutoTokenizer.from_pretrained(
    './multilingual-e5-base')
embeddingModel = AutoModel.from_pretrained('./multilingual-e5-base')
# chatGpt replacement
inferenceTokenizer = AutoTokenizer.from_pretrained(
    "./ct2fast-flan-alpaca-xl")
inferenceTranslator = Translator(
    "./ct2fast-flan-alpaca-xl", compute_type="int8", device="cpu")
class EmbeddingRequest(BaseModel):
    input: Union[str, None] = None


class TokensCountRequest(BaseModel):
    input: Union[str, None] = None


class InferenceRequest(BaseModel):
    input: Union[str, None] = None
    max_length: Union[int, None] = 0

app = FastAPI()


@app.get("/")
async def root():
    return {"message": "Hello World"}

@app.post("/text-embedding")
async def text_embedding(request: EmbeddingRequest):
input = request.input
# Process the input data
batch_dict = embeddingTokenizer([input], max_length=512,
padding=True, truncation=True, return_tensors='pt')
outputs = embeddingModel(**batch_dict)
embeddings = average_pool(outputs.last_hidden_state,
batch_dict['attention_mask'])
# create response
return {
'embedding': embeddings[0].tolist()
}
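# Example request (payload shape from EmbeddingRequest above; host and port
# are assumptions, adjust to your deployment):
#   curl -X POST http://localhost:7860/text-embedding \
#       -H 'Content-Type: application/json' \
#       -d '{"input": "query: how do embeddings work?"}'
# e5-family models were trained with "query: "/"passage: " prefixes on the
# text, so prefixing as above tends to produce better embeddings.
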
@app.post('/inference')
async def inference(request: InferenceRequest):
    input_text = request.input
    # Clamp the requested decoding length to (0, 1024]; fall back to 256
    # when the field is missing, non-numeric, or non-positive.
    max_length = 256
    try:
        requested = int(request.max_length)
        if requested > 0:
            max_length = min(1024, requested)
    except (TypeError, ValueError):
        pass
    # process request
    input_tokens = inferenceTokenizer.convert_ids_to_tokens(
        inferenceTokenizer.encode(input_text))
    results = inferenceTranslator.translate_batch(
        [input_tokens],
        beam_size=1,
        max_input_length=0,
        max_decoding_length=max_length,
        num_hypotheses=1,
        repetition_penalty=1.3,
        sampling_topk=40,
        sampling_temperature=0.7,
        use_vmap=False)
    output_tokens = results[0].hypotheses[0]
    output_text = inferenceTokenizer.decode(
        inferenceTokenizer.convert_tokens_to_ids(output_tokens))
    # create response
    return {
        'generated_text': output_text
    }
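# Example request (payload shape from InferenceRequest above; host and port
# are assumptions):
#   curl -X POST http://localhost:7860/inference \
#       -H 'Content-Type: application/json' \
#       -d '{"input": "Explain what an embedding is.", "max_length": 128}'
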
@app.post('/tokens-count')
async def tokens_count(request: TokensCountRequest):
    input_text = request.input
    # Tokenize with the inference tokenizer so counts match what /inference sees
    tokens = inferenceTokenizer.convert_ids_to_tokens(
        inferenceTokenizer.encode(input_text))
    # create response
    return {
        'tokens': tokens,
        'total': len(tokens)
    }
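
if __name__ == "__main__":
    # Minimal sketch for serving the app locally, assuming uvicorn is
    # installed. Port 7860 (the Hugging Face Spaces default) is an assumption.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)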