Spaces: Runtime error
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments, pipeline

app = FastAPI()

# Initialize the tokenizer and model (T5 has no causal-LM class; use the seq2seq head)
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")
with open("cyberpunk_lore.txt", "r") as f:
    dataset = f.read()

# Tokenize the dataset line by line (the tokenizer expects a list of strings, not one raw string)
lines = [line for line in dataset.splitlines() if line.strip()]
encodings = tokenizer(lines, return_tensors="pt", padding=True, truncation=True)
# Trainer needs an indexable dataset of feature dicts; reuse each line as its own label
train_data = [{"input_ids": ids, "labels": ids} for ids in encodings["input_ids"]]
# Set up training arguments
training_args = TrainingArguments(
    output_dir='./results',
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=1,
    save_steps=10_000,
    save_total_limit=2,
)
# Create a Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=train_data,
)
# Fine-tune the model
trainer.train()

# Create the inference pipeline (the tokenizer must be passed explicitly when handing over a model object)
pipe_flan = pipeline("text2text-generation", model=model, tokenizer=tokenizer)

# Register the inference function as a GET route so it is actually reachable
@app.get("/infer_t5")
def t5(input):
    output = pipe_flan(input)
    return {"output": output[0]["generated_text"]}
app.mount("/", StaticFiles(directory="static", html=True), name="static")

@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="/app/static/index.html", media_type="text/html")
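Once the Space (or a local uvicorn run) is up, the endpoint can be sanity-checked with a small client call. This is a minimal sketch, not part of the app: the base URL, the port 7860, and the prompt text are assumptions; only the /infer_t5 route, its input query parameter, and the "output" key come from the code above.

import requests

# Placeholder base URL; replace with your Space's public URL or your local host/port
url = "http://localhost:7860/infer_t5"
# The route takes the prompt as the "input" query parameter and returns JSON
response = requests.get(url, params={"input": "Tell me about the streets of Night City"})
print(response.json()["output"])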