# fastapi_t5 / main.py
# (Hugging Face Space file; commit 62a4b51 "Update main.py" by streetyogi)
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
import torch
from transformers import (
    T5Tokenizer,
    # T5 is an encoder-decoder model: transformers exposes
    # T5ForConditionalGeneration; `T5ForCausalLM` does not exist and the
    # original import raised at startup.
    T5ForConditionalGeneration,
    Trainer,
    TrainingArguments,
)

app = FastAPI()

# Initialize the tokenizer and model.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

# Load the raw corpus; one non-empty line = one training example.
with open("cyberpunk_lore.txt", "r", encoding="utf-8") as f:
    texts = [line.strip() for line in f if line.strip()]


class _LineDataset(torch.utils.data.Dataset):
    """Tokenized-line dataset the HF ``Trainer`` can actually batch.

    Each item is a dict with fixed-length ``input_ids`` / ``attention_mask``
    and ``labels`` (padding positions set to -100 so the loss ignores them).
    The original code passed a raw tensor of ids straight to ``Trainer``,
    which cannot collate it and supplies no labels for the loss.
    """

    def __init__(self, lines, tok, max_length=128):
        self._items = []
        for text in lines:
            enc = tok(
                text,
                truncation=True,
                padding="max_length",  # uniform length -> default collator works
                max_length=max_length,
                return_tensors="pt",
            )
            ids = enc["input_ids"].squeeze(0)
            labels = ids.clone()
            labels[labels == tok.pad_token_id] = -100  # mask padding in the loss
            self._items.append(
                {
                    "input_ids": ids,
                    "attention_mask": enc["attention_mask"].squeeze(0),
                    "labels": labels,
                }
            )

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        return self._items[idx]


train_dataset = _LineDataset(texts, tokenizer)

# Set up training arguments.
training_args = TrainingArguments(
    output_dir='./results',
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=1,
    save_steps=10_000,
    save_total_limit=2,
)

# Create a Trainer (evaluating on the train set mirrors the original code).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=train_dataset,
)

# Fine-tune the model.
# NOTE(review): training in module scope blocks server startup on every boot
# of the Space — consider training offline and loading the checkpoint instead.
trainer.train()
# Inference endpoint.
# The original built `pipeline("text2text-generation", model=model)` here,
# but `pipeline` was never imported (NameError at import time) and no
# tokenizer was supplied; we call the model directly instead.
@app.get("/infer_t5")
def t5(input):
    """Generate a text2text completion for the query parameter ``input``.

    ``input`` shadows the builtin, but it is also the public query-parameter
    name of this endpoint, so it is kept for backward compatibility.
    Returns ``{"output": <generated text>}`` as before.
    """
    ids = tokenizer(input, return_tensors="pt").input_ids
    generated = model.generate(ids)
    text = tokenizer.decode(generated[0], skip_special_tokens=True)
    return {"output": text}
# Register the explicit "/" route BEFORE the catch-all static mount:
# Starlette matches routes in registration order, so the original ordering
# (mount first) made index() unreachable.
@app.get("/")
def index() -> FileResponse:
    """Serve the single-page app entry point."""
    # NOTE(review): path is absolute ("/app/static/...") while the mount uses
    # the relative "static" directory — confirm both resolve to the same files.
    return FileResponse(path="/app/static/index.html", media_type="text/html")

# Catch-all static file mount; registered last so explicit routes win.
app.mount("/", StaticFiles(directory="static", html=True), name="static")