from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from transformers import pipeline, set_seed

# Load a small GPT-2 text-generation pipeline once at import time and fix
# the sampling seed so repeated requests are reproducible.
generator = pipeline('text-generation', model='gpt2')
set_seed(40)

# GPT-2 has no dedicated padding token; reuse the end-of-sequence token.
generator.tokenizer.pad_token_id = generator.tokenizer.eos_token_id

app = FastAPI()

# Placeholders for the StableLM model and tokenizer; they are only populated
# if the commented-out startup hook at the bottom of this file is re-enabled.
MODEL = None
TOKENIZER = None

# Example query string: ?input=%22Name%203%20shows%22
origins = ['https://aiforall.netlify.app']

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
def llama(input):
    # prompt = [{'role': 'user', 'content': ""+input}]
    # inputs = TOKENIZER.apply_chat_template( prompt, add_generation_prompt=True,     return_tensors='pt' )

    # tokens = MODEL.generate( inputs.to(MODEL.device), max_new_tokens=1024, temperature=0.3, do_sample=True)

    # tresponse = TOKENIZER.decode(tokens[0], skip_special_tokens=False)
    # print(tresponse)
    outputs = generator(input, max_length=100, num_return_sequences=5)
    text = ""
    for o in outputs:
        text = text + ' ' + o["generated_text"]
    response_message = {"message": text}
    json_response = jsonable_encoder(response_message)
    return JSONResponse(content=json_response)
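
# To serve this API (the module name `main` is assumed here, not given in
# the source):
#   uvicorn main:app --host 0.0.0.0 --port 8000
# Example request: GET /?input=%22Name%203%20shows%22
# Response shape: {"message": "<completion 1> <completion 2> ..."}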
 
# Startup hook for the StableLM path (currently disabled). Re-enabling it
# also requires: from transformers import AutoModelForCausalLM, AutoTokenizer
# @app.on_event("startup")
# def init_model():
#     global MODEL, TOKENIZER
#     if not MODEL:
#         print("loading model")
#         TOKENIZER = AutoTokenizer.from_pretrained('stabilityai/stablelm-zephyr-3b')
#         MODEL = AutoModelForCausalLM.from_pretrained('stabilityai/stablelm-zephyr-3b', device_map="auto")
#         print("loaded model")
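
# In-process smoke test: a minimal sketch added for illustration, not part
# of the original service. It exercises the endpoint without starting a
# server; FastAPI's TestClient requires the `httpx` package (older FastAPI
# versions use `requests` instead). The guard keeps it from running when
# the module is imported by uvicorn.
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    client = TestClient(app)
    resp = client.get("/", params={"input": "Name 3 shows"})
    print(resp.json()["message"])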