"""FastAPI service exposing a seq2seq text-summarization endpoint."""

from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the model and tokenizer once, at import time, so every request reuses them.
model_name = "shahzaib201/AI_OEL"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


class TextInput(BaseModel):
    """Request body for ``POST /summarize_text``."""

    # Text to summarize; the tokenizer truncates it to 1024 tokens.
    text: str
    # Forwarded to ``model.generate`` as its ``max_length`` limit.
    max_length: int = 150


# Initialize FastAPI app
app = FastAPI()


def _summarize(text: str, max_length: int) -> str:
    """Tokenize *text*, run beam-search generation, and decode the best sequence.

    This is synchronous, CPU/GPU-bound work. Callers inside the event loop
    must dispatch it to a worker thread rather than calling it directly.
    """
    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
    summary_ids = model.generate(
        inputs.input_ids,
        # Pass the mask explicitly instead of letting generate() infer it (and warn).
        attention_mask=inputs.get("attention_mask"),
        max_length=max_length,
        num_beams=4,
        length_penalty=2.0,
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


@app.post("/summarize_text")
async def summarize_text_endpoint(item: TextInput):
    """Summarize ``item.text`` and return ``{"summary": <str>}``.

    Inference is blocking, so it runs in FastAPI's thread pool. Calling it
    inline would freeze the event loop and stall all other requests while
    a summary is being generated.
    NOTE(review): concurrent requests now run generation in parallel
    threads; confirm memory headroom for the expected load.
    """
    summary = await run_in_threadpool(_summarize, item.text, item.max_length)
    return {"summary": summary}