# bevelapi / main.py
# Last change: "Update main.py" by BeveledCube (commit 781452b, verified), 3.1 kB.
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from fastapi import FastAPI
import os
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer
import torch
app = FastAPI()

# Hugging Face model id that this service loads and serves.
name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Other checkpoint ids noted in the original file:
#   microsoft/DialoGPT-small
#   microsoft/DialoGPT-medium
#   microsoft/DialoGPT-large
#   mistralai/Mixtral-8x7B-Instruct-v0.1

# When True *and* `name` is one of the DialoGPT ids, POST /api takes the
# DialoGPT-style generation path instead of beam search.
customGen = False
# When True, POST /api generates with the GPT-2 specific classes below
# rather than the Auto* model/tokenizer.
gpt2based = False

# Load the checkpoint once at import time so requests reuse the same objects.
model = AutoModelForCausalLM.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

# The GPT-2 classes are only consulted when `gpt2based` is set. Loading a
# non-GPT-2 checkpoint (such as the default TinyLlama id) through them is a
# model-type mismatch and keeps a second model in memory, so they are loaded
# only when they will actually be used. The names stay defined either way.
if gpt2based:
    gpt2model = GPT2LMHeadModel.from_pretrained(name)
    gpt2tokenizer = GPT2Tokenizer.from_pretrained(name)
else:
    gpt2model = None
    gpt2tokenizer = None
class req(BaseModel):
    """Request body accepted by POST /api."""

    # Text handed to the tokenizer as the generation prompt.
    prompt: str
    # Passed as `max_length` to `generate()` on the beam-search paths;
    # the DialoGPT path does not read it and uses a fixed limit of 1000.
    length: int
@app.get("/")
def read_root():
    """Serve the HTML page stored at templates/index.html."""
    index_page = FileResponse(
        path="templates/index.html",
        media_type="text/html",
    )
    return index_page
# Checkpoint ids for which the DialoGPT-style generation path may be used.
_DIALOGPT_NAMES = (
    "microsoft/DialoGPT-small",
    "microsoft/DialoGPT-medium",
    "microsoft/DialoGPT-large",
)


def _beam_search_answer(lm, tok, prompt, max_length):
    """Encode `prompt` with `tok`, beam-search with `lm`, and decode the result.

    Args:
        lm: Model object exposing `generate()`.
        tok: Tokenizer object exposing `encode()` / `decode()`.
        prompt: Text to continue.
        max_length: Forwarded to `generate()` as `max_length`.

    Returns:
        The first returned sequence decoded to text, special tokens skipped.
    """
    input_ids = tok.encode(prompt, return_tensors="pt")
    output_ids = lm.generate(
        input_ids,
        max_length=max_length,
        num_beams=5,
        no_repeat_ngram_size=2,
    )
    return tok.decode(output_ids[0], skip_special_tokens=True)


def _dialogpt_answer(prompt):
    """Generate a single-turn reply with the module-level model/tokenizer.

    The prompt is terminated with the tokenizer's EOS token before encoding.
    No chat history is kept between requests, so the encoded prompt is the
    whole model input.
    """
    new_user_input_ids = tokenizer.encode(
        prompt + tokenizer.eos_token, return_tensors="pt"
    )
    # The previous code called torch.cat() on this single tensor, which raises
    # TypeError (cat expects a sequence of tensors). With no stored history
    # there is nothing to concatenate, so the prompt ids are used directly.
    chat_history_ids = model.generate(
        new_user_input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
    )
    # NOTE(review): this decodes the full sequence, so the returned text
    # starts with the prompt. Slice from new_user_input_ids.shape[-1] if only
    # the reply is wanted - confirm intended behavior.
    return tokenizer.decode(chat_history_ids[0], skip_special_tokens=True)


@app.post("/api")
def read_root(data: req):
    """Generate text for the posted prompt and return it as JSON.

    NOTE: this handler shares its function name with the GET / handler defined
    earlier in the file. Both routes remain registered because the decorators
    run at definition time, but the module-level name refers to this function.

    Args:
        data: Request body carrying `prompt` and `length`.

    Returns:
        A dict of the form {"answer": <generated text>}.
    """
    print("Prompt:", data.prompt)
    print("Length:", data.length)

    if customGen and name in _DIALOGPT_NAMES:
        generated_text = _dialogpt_answer(data.prompt)
    elif gpt2based:
        generated_text = _beam_search_answer(
            gpt2model, gpt2tokenizer, data.prompt, data.length
        )
    else:
        generated_text = _beam_search_answer(
            model, tokenizer, data.prompt, data.length
        )

    answer_data = {"answer": generated_text}
    print("Answer:", generated_text)
    return answer_data