# textgeneration / question_paper.py
# Yash Sachdeva
# gpt2
# d248ced
# raw
# history blame
# 1.68 kB
from transformers import pipeline, set_seed
# Build the GPT-2 text-generation pipeline once, at import time.
generator = pipeline(task='text-generation', model='gpt2')
# Seed the random number generators after the pipeline has been created.
set_seed(seed=40)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
# Use the end-of-sequence token id as the padding id (GPT-2 ships without a
# dedicated padding token). The tokenizer is reached through the pipeline
# object: a bare `tokenizer` name is never defined in this module, so
# referencing it directly raised NameError at import time.
generator.tokenizer.pad_token_id = generator.tokenizer.eos_token_id

app = FastAPI()

# Placeholders for the model/tokenizer pair used by the commented-out
# StableLM code in this file; nothing assigns them while that code is disabled.
MODEL = None
TOKENIZER = None

# Example query string for the "/" route: ?input=%22Name%203%20shows%22

# The only origin allowed to make cross-origin requests to this API.
origins = ['https://aiforall.netlify.app']
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
def llama(input):
    """Generate continuations of ``input`` and return them as one JSON message.

    The module-level ``generator`` pipeline is asked for five sequences with
    ``max_length=100``. Their ``generated_text`` values are concatenated,
    each preceded by a single space, and returned under the ``"message"``
    key of a ``JSONResponse``.
    """
    # Disabled chat-template path, kept for reference:
    # prompt = [{'role': 'user', 'content': ""+input}]
    # inputs = TOKENIZER.apply_chat_template( prompt, add_generation_prompt=True, return_tensors='pt' )
    # tokens = MODEL.generate( inputs.to(MODEL.device), max_new_tokens=1024, temperature=0.3, do_sample=True)
    # tresponse = TOKENIZER.decode(tokens[0], skip_special_tokens=False)
    # print(tresponse)
    candidates = generator(input, max_length=100, num_return_sequences=5)
    # Every candidate is prefixed with a space, so the result also starts with one.
    combined = "".join(' ' + candidate["generated_text"] for candidate in candidates)
    payload = jsonable_encoder({"message": combined})
    return JSONResponse(content=payload)
# @app.on_event("startup")
# def init_model():
# global MODEL
# global TOKENIZER
# if not MODEL:
# print("loading model")
# TOKENIZER = AutoTokenizer.from_pretrained('stabilityai/stablelm-zephyr-3b')
# MODEL = AutoModelForCausalLM.from_pretrained('stabilityai/stablelm-zephyr-3b', device_map="auto")
# print("loaded model")