vasilee committed
Commit 83dacf1
1 Parent(s): b618c9e

Update main.py

Files changed (1)
  1. main.py +14 -78
main.py CHANGED
@@ -1,105 +1,41 @@
-from torch import Tensor
-from transformers import AutoTokenizer, AutoModel
-from ctranslate2 import Translator
 from typing import Union
-
 from fastapi import FastAPI
 from pydantic import BaseModel
-
-
-def average_pool(last_hidden_states: Tensor,
-                 attention_mask: Tensor) -> Tensor:
-    last_hidden = last_hidden_states.masked_fill(
-        ~attention_mask[..., None].bool(), 0.0)
-    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
-
-
-# text-ada replacement
-embeddingTokenizer = AutoTokenizer.from_pretrained(
-    './multilingual-e5-base')
-embeddingModel = AutoModel.from_pretrained('./multilingual-e5-base')
-
-# chatGpt replacement
-inferenceTokenizer = AutoTokenizer.from_pretrained(
-    "./flan-alpaca-gpt4-xl-ct2")
-inferenceTranslator = Translator(
-    "./flan-alpaca-gpt4-xl-ct2", compute_type="int8", device="cpu")
-
-
-class EmbeddingRequest(BaseModel):
-    input: Union[str, None] = None
-
-
-class TokensCountRequest(BaseModel):
-    input: Union[str, None] = None
+from llama_cpp import Llama
 
 
 class InferenceRequest(BaseModel):
     input: Union[str, None] = None
-    max_length: Union[int, None] = 0
+    max_tokens: Union[int, None] = 0
 
 
 app = FastAPI()
 
+llm = Llama(model_path="./models/vicuna-7b-v1.5.Q4_K_M.gguf",
+            verbose=False, n_ctx=4096)
+
 
 @app.get("/")
 async def root():
     return {"message": "Hello World"}
 
 
-@app.post("/text-embedding")
-async def text_embedding(request: EmbeddingRequest):
-    input = request.input
-
-    # Process the input data
-    batch_dict = embeddingTokenizer([input], max_length=512,
-                                    padding=True, truncation=True, return_tensors='pt')
-    outputs = embeddingModel(**batch_dict)
-    embeddings = average_pool(outputs.last_hidden_state,
-                              batch_dict['attention_mask'])
-
-    # create response
-    return {
-        'embedding': embeddings[0].tolist()
-    }
-
-
 @app.post('/inference')
 async def inference(request: InferenceRequest):
     input_text = request.input
-    max_length = 256
+    max_tokens = 256
     try:
-        max_length = int(request.max_length)
-        max_length = min(1024, max_length)
+        max_tokens = int(request.max_tokens)
     except:
         pass
 
     # process request
-    input_tokens = inferenceTokenizer.convert_ids_to_tokens(
-        inferenceTokenizer.encode(input_text))
-
-    results = inferenceTranslator.translate_batch(
-        [input_tokens], beam_size=1, max_input_length=0, max_decoding_length=max_length, num_hypotheses=1, repetition_penalty=1.3, sampling_topk=40, sampling_temperature=0.7, use_vmap=False)
-
-    output_tokens = results[0].hypotheses[0]
-    output_text = inferenceTokenizer.decode(
-        inferenceTokenizer.convert_tokens_to_ids(output_tokens), skip_special_tokens=True)
-
-    # create response
-    return {
-        'generated_text': output_text
-    }
-
-
-@app.post('/tokens-count')
-async def tokens_count(request: TokensCountRequest):
-    input_text = request.input
-
-    tokens = inferenceTokenizer.convert_ids_to_tokens(
-        inferenceTokenizer.encode(input_text))
+    try:
+        result = llm(input_text, temperature=0.2,
+                     top_k=5, max_tokens=max_tokens)
+        return result
+    except:
+        pass
 
     # create response
-    return {
-        'tokens': tokens,
-        'total': len(tokens)
-    }
+    return {}
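
For reference, a minimal sketch of how the updated /inference endpoint could be called once the app is running. It assumes the server is started with something like `uvicorn main:app` on the default port 8000 and that the `requests` package is available; neither is specified in this commit.

# Hypothetical client for the new /inference endpoint.
# Host, port and the use of `requests` are assumptions, not part of the commit.
import requests

payload = {
    "input": "Briefly explain what a GGUF model file is.",
    "max_tokens": 128,  # maps to InferenceRequest.max_tokens
}
resp = requests.post("http://127.0.0.1:8000/inference", json=payload, timeout=300)

# On success the endpoint returns llama-cpp-python's completion dict
# (OpenAI-style, with a "choices" list); if generation fails it returns {}.
data = resp.json()
if data.get("choices"):
    print(data["choices"][0]["text"])
else:
    print("empty response:", data)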