textgeneration / question_paper.py
Yash Sachdeva
stableLM
fdc39d2
raw
history blame
966 Bytes
import transformers
import torch
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
# Application object that the route handler and the startup hook register on.
app = FastAPI()

# Both start out empty and are filled in by the startup hook defined below;
# request handlers read them as module globals.
TOKENIZER = None
MODEL = None
# Example query string for the endpoint below: ?input=%22Name%203%20shows%22
@app.get("/")
def llama(input):
    """Generate a sampled chat reply for a single user message.

    Registered on ``app`` as the GET handler for "/". The argument is
    wrapped as one user-role chat message, rendered through the tokenizer's
    chat template with a generation prompt appended, and passed to the
    model for sampling.

    The module globals ``MODEL`` and ``TOKENIZER`` must already hold loaded
    objects; both are ``None`` when the module is first imported.

    Args:
        input: Text of the user's message. The name shadows the builtin but
            is kept because it is the parameter name callers use.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens left in. The same text is also printed to stdout.
    """
    conversation = [{'role': 'user', 'content': input}]
    prompt_ids = TOKENIZER.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors='pt',
    )

    # Sampling settings gathered in one place for readability.
    sampling = {
        'max_new_tokens': 1024,
        'temperature': 0.3,
        'do_sample': True,
    }
    generated = MODEL.generate(prompt_ids.to(MODEL.device), **sampling)

    # NOTE(review): the first sequence presumably contains the prompt tokens
    # as well as the reply -- confirm against generate()'s output.
    reply = TOKENIZER.decode(generated[0], skip_special_tokens=False)
    print(reply)
    return reply
@app.on_event("startup")
def init_model():
    """Load the tokenizer and causal LM into the module globals.

    Registered on ``app`` for the "startup" event. When ``MODEL`` is already
    truthy the function returns without doing anything; otherwise it loads
    the tokenizer and the model from the same checkpoint id, stores them in
    ``TOKENIZER`` and ``MODEL``, and prints a message before and after.
    """
    global MODEL, TOKENIZER

    # Already loaded: nothing to do.
    if MODEL:
        return

    checkpoint = 'stabilityai/stablelm-zephyr-3b'
    print("loading model")
    TOKENIZER = AutoTokenizer.from_pretrained(checkpoint)
    MODEL = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")
    print("loaded model")